Compare commits

...

37 Commits

Author SHA1 Message Date
pascal
9b10d74ab8 copy zones and records for testing 2026-04-24 15:50:49 +02:00
pascal
312bcf6398 remove service user exception 2026-04-24 15:25:20 +02:00
pascal
381858a865 add accountID to test data 2026-04-23 17:53:04 +02:00
pascal
4051af499a fix error response 2026-04-23 16:29:43 +02:00
pascal
d774e9a62a Merge branch 'main' into refactor/permissions-manager
# Conflicts:
#	management/server/http/testing/testing_tools/channel/channel.go
#	management/server/policy.go
2026-04-23 16:06:01 +02:00
pascal
a4b55af99c add additional integration tests 2026-04-20 17:24:04 +02:00
pascal
470307079b update account isolation 2026-04-20 16:33:24 +02:00
pascal
0b04c0d03b leave todo comments 2026-04-20 15:03:31 +02:00
pascal
bed8d89d9f account for peers permission on the groups endpoint 2026-04-20 14:00:57 +02:00
pascal
b65a8bcb9c update sql test data files and auth wrappers 2026-04-20 13:44:38 +02:00
pascal
a53c38a6ed add sql test files 2026-04-17 17:09:38 +02:00
pascal
8a41117403 fix version test 2026-04-17 16:51:08 +02:00
pascal
e0063731f2 update tests 2026-04-17 16:44:57 +02:00
pascal
a572d8819f fix user handler test 2026-04-17 14:59:19 +02:00
pascal
d139a4fc42 fix client cmd test 2026-04-17 14:34:39 +02:00
pascal
dbe533327f fix client server test 2026-04-17 13:23:26 +02:00
pascal
4835909a04 fix client test 2026-04-17 13:16:33 +02:00
pascal
fb6447e714 fix engine test 2026-04-16 18:44:40 +02:00
pascal
16d6cc1265 update management integrations 2026-04-16 18:36:55 +02:00
pascal
53e47da7bd fix tests 2026-04-16 18:24:47 +02:00
pascal
ce522ea69b fix mock 2026-04-16 18:16:44 +02:00
pascal
5be80e976f fix mock 2026-04-16 17:55:14 +02:00
pascal
dd9539a9bf fix gomod 2026-04-16 17:48:11 +02:00
pascal
8530f9b8fc Merge branch 'main' into refactor/permissions-manager
# Conflicts:
#	go.mod
#	go.sum
#	management/server/account_test.go
#	management/server/http/testing/testing_tools/channel/channel.go
2026-04-16 17:47:30 +02:00
pascal
9406de9610 add exception handler and add exception for users 2026-04-16 17:42:59 +02:00
pascal
9e385eb540 add optional auth error handlers 2026-04-16 16:24:55 +02:00
pascal
e46ea895c1 Merge branch 'main' into refactor/permissions-manager
# Conflicts:
#	management/server/account_test.go
#	management/server/http/handlers/networks/routers_handler.go
2026-04-16 14:52:46 +02:00
pascal
31d901c4b0 update role permissions for admins 2026-04-08 12:15:01 +02:00
pascal
20e6dff507 Merge branch 'main' into refactor/permissions-manager 2026-04-07 17:35:39 +02:00
pascal
f5c8a6fe1a update integrations 2026-03-27 16:46:08 +01:00
pascal
beee14b9bf fix merge conflicts 2026-03-27 14:47:06 +01:00
pascal
3013c98ab5 Merge branch 'main' into refactor/permissions-manager
# Conflicts:
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/server/http/testing/testing_tools/channel/channel.go
2026-03-27 14:37:29 +01:00
pascal
3741eb46dd Merge remote-tracking branch 'origin/main' into refactor/permissions-manager
# Conflicts:
#	management/internals/modules/reverseproxy/domain/manager/manager.go
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/internals/server/modules.go
#	management/server/http/testing/testing_tools/channel/channel.go
2026-03-17 12:38:08 +01:00
pascal
d85ee0b5a2 remove old permissions management 2026-03-06 18:19:33 +01:00
pascal
da4a0eb68a Merge branch 'main' into refactor/permissions-manager 2026-03-06 17:06:18 +01:00
pascal
b0ce0048b4 Merge branch 'refs/heads/main' into refactor/permissions-manager
# Conflicts:
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/server/http/handler.go
2026-03-06 11:37:53 +01:00
pascal
32730af33f use api wrapper for permissions management 2026-02-24 00:39:54 +01:00
145 changed files with 4180 additions and 3657 deletions

View File

@@ -13,6 +13,8 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -29,7 +31,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -97,7 +98,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store, permissionsManagerMock) peersmanager := peers.NewManager(store)
settingsManagerMock := settings.NewMockManager(ctrl) settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersmanager) jobManager := job.NewJobManager(nil, store, peersmanager)

View File

@@ -25,6 +25,7 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
@@ -57,7 +58,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -1632,7 +1632,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
} }
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager) peersManager := peers.NewManager(store)
jobManager := job.NewJobManager(nil, store, peersManager) jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)

View File

@@ -15,6 +15,8 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
@@ -38,7 +40,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -305,7 +306,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock) peersManager := peers.NewManager(store)
settingsManagerMock := settings.NewMockManager(ctrl) settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager) jobManager := job.NewJobManager(nil, store, peersManager)

2
go.mod
View File

@@ -71,7 +71,7 @@ require (
github.com/mdlayher/socket v0.5.1 github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2 github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0

4
go.sum
View File

@@ -453,8 +453,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8= github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34 h1:g74mB64wnjCagzE1spKgPfTI/ont1SdSL3uX5bOecgM=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg= github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34/go.mod h1:lCOq5d1i19AQjEEW2d7aNK0Nn0KC0MKyfMz/PLwVBFg=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -16,9 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -38,17 +35,15 @@ type Manager interface {
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
permissionsManager permissions.Manager
integratedPeerValidator integrated_validator.IntegratedValidator integratedPeerValidator integrated_validator.IntegratedValidator
accountManager account.Manager accountManager account.Manager
networkMapController network_map.Controller networkMapController network_map.Controller
} }
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { func NewManager(store store.Store) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager,
} }
} }
@@ -65,28 +60,10 @@ func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
} }
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
} }
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
} }

View File

@@ -4,20 +4,29 @@ package permissions
import ( import (
"context" "context"
"net/http"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
// AuthErrorHandler is called when an auth error occurs during permission validation.
// If it returns true, the error is considered handled and the default error response is skipped.
type AuthErrorHandler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool
type Manager interface { type Manager interface {
WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth), authErrHandler ...AuthErrorHandler) http.HandlerFunc
ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error)
ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
@@ -36,6 +45,51 @@ func NewManager(store store.Store) Manager {
} }
} }
// WithPermission wraps an HTTP handler with permission checking logic.
// An optional AuthErrorHandler can be provided to intercept auth errors before the default response is written.
func (m *managerImpl) WithPermission(
module modules.Module,
operation operations.Operation,
handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth),
authErrHandler ...AuthErrorHandler,
) http.HandlerFunc {
var onAuthErr AuthErrorHandler
if len(authErrHandler) > 0 {
onAuthErr = authErrHandler[0]
}
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err)
util.WriteError(r.Context(), err, w)
return
}
allowed, err := m.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, module, operation)
if err != nil {
if onAuthErr != nil && onAuthErr(w, r, &userAuth, err) {
return
}
log.WithContext(r.Context()).Errorf("failed to validate permissions for user %s on account %s: %v", userAuth.UserId, userAuth.AccountId, err)
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
return
}
if !allowed {
permErr := status.NewPermissionDeniedError()
if onAuthErr != nil && onAuthErr(w, r, &userAuth, permErr) {
return
}
log.WithContext(r.Context()).Tracef("user %s on account %s is not allowed to %s in %s", userAuth.UserId, userAuth.AccountId, operation, module)
util.WriteError(r.Context(), permErr, w)
return
}
handlerFunc(w, r, &userAuth)
}
}
func (m *managerImpl) ValidateUserPermissions( func (m *managerImpl) ValidateUserPermissions(
ctx context.Context, ctx context.Context,
accountID string, accountID string,
@@ -68,10 +122,6 @@ func (m *managerImpl) ValidateUserPermissions(
return false, err return false, err
} }
if operation == operations.Read && user.IsServiceUser {
return true, nil // this should be replaced by proper granular access role
}
role, ok := roles.RolesMap[user.Role] role, ok := roles.RolesMap[user.Role]
if !ok { if !ok {
return false, status.NewUserRoleNotFoundError(string(user.Role)) return false, status.NewUserRoleNotFoundError(string(user.Role))
@@ -127,3 +177,17 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR
func (m *managerImpl) SetAccountManager(accountManager account.Manager) { func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
// no-op // no-op
} }
// WrapHandler wraps a handler that expects UserAuth with context extraction.
// Unlike WithPermission, it does not perform any permission checks.
func WrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err)
util.WriteError(r.Context(), err, w)
return
}
h(w, r, &userAuth)
}
}

View File

@@ -5,15 +5,18 @@
package permissions package permissions
import ( import (
context "context" "context"
reflect "reflect" "net/http"
"reflect"
gomock "github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
account "github.com/netbirdio/netbird/management/server/account"
modules "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
operations "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
roles "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
types "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
) )
// MockManager is a mock of Manager interface. // MockManager is a mock of Manager interface.
@@ -108,3 +111,22 @@ func (mr *MockManagerMockRecorder) ValidateUserPermissions(ctx, accountID, userI
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserPermissions", reflect.TypeOf((*MockManager)(nil).ValidateUserPermissions), ctx, accountID, userID, module, operation) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserPermissions", reflect.TypeOf((*MockManager)(nil).ValidateUserPermissions), ctx, accountID, userID, module, operation)
} }
// WithPermission mocks base method.
func (m *MockManager) WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(http.ResponseWriter, *http.Request, *auth.UserAuth), authErrHandler ...AuthErrorHandler) http.HandlerFunc {
m.ctrl.T.Helper()
varargs := []interface{}{module, operation, handlerFunc}
for _, a := range authErrHandler {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "WithPermission", varargs...)
ret0, _ := ret[0].(http.HandlerFunc)
return ret0
}
// WithPermission indicates an expected call of WithPermission.
func (mr *MockManagerMockRecorder) WithPermission(module, operation, handlerFunc interface{}, authErrHandler ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{module, operation, handlerFunc}, authErrHandler...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithPermission", reflect.TypeOf((*MockManager)(nil).WithPermission), varargs...)
}

View File

@@ -1,8 +1,8 @@
package roles package roles
import ( import (
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
@@ -18,7 +18,7 @@ var Admin = RolePermissions{
modules.Accounts: { modules.Accounts: {
operations.Read: true, operations.Read: true,
operations.Create: false, operations.Create: false,
operations.Update: false, operations.Update: true,
operations.Delete: false, operations.Delete: false,
}, },
}, },

View File

@@ -1,7 +1,7 @@
package roles package roles
import ( import (
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -1,8 +1,8 @@
package roles package roles
import ( import (
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -1,7 +1,7 @@
package roles package roles
import ( import (
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -1,8 +1,8 @@
package roles package roles
import ( import (
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -1,7 +1,7 @@
package roles package roles
import ( import (
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -5,8 +5,11 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
@@ -15,21 +18,15 @@ type handler struct {
manager accesslogs.Manager manager accesslogs.Manager
} }
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) { func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager, permissionsManager permissions.Manager) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS") router.HandleFunc("/events/proxy", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAccessLogs)).Methods("GET", "OPTIONS")
} }
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) { func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var filter accesslogs.AccessLogFilter var filter accesslogs.AccessLogFilter
filter.ParseFromRequest(r) filter.ParseFromRequest(r)

View File

@@ -9,25 +9,19 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
) )
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
permissionsManager permissions.Manager geo geolocation.Geolocation
geo geolocation.Geolocation cleanupCancel context.CancelFunc
cleanupCancel context.CancelFunc
} }
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager { func NewManager(store store.Store, geo geolocation.Geolocation) accesslogs.Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager, geo: geo,
geo: geo,
} }
} }
@@ -63,14 +57,6 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering // GetAllAccessLogs retrieves access logs for an account with pagination and filtering
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) { func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, 0, status.NewPermissionValidationError(err)
}
if !ok {
return nil, 0, status.NewPermissionDeniedError()
}
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil { if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err) log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
} }

View File

@@ -6,8 +6,11 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -17,15 +20,15 @@ type handler struct {
manager Manager manager Manager
} }
func RegisterEndpoints(router *mux.Router, manager Manager) { func RegisterEndpoints(router *mux.Router, manager Manager, permissionsManager permissions.Manager) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS") router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllDomains)).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS") router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Create, h.createCustomDomain)).Methods("POST", "OPTIONS")
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS") router.HandleFunc("/domains/{domainId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteCustomDomain)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS") router.HandleFunc("/domains/{domainId}/validate", permissionsManager.WithPermission(modules.Services, operations.Create, h.triggerCustomDomainValidation)).Methods("GET", "OPTIONS") // TODO: this should be a POST
} }
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType { func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
@@ -56,13 +59,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
return resp return resp
} }
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId) domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@@ -77,13 +74,7 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, ret) util.WriteJSONObject(r.Context(), w, ret)
} }
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) { func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiReverseProxiesDomainsJSONRequestBody var req api.PostApiReverseProxiesDomainsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -99,13 +90,7 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, domainToApi(domain)) util.WriteJSONObject(r.Context(), w, domainToApi(domain))
} }
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domainID := mux.Vars(r)["domainId"] domainID := mux.Vars(r)["domainId"]
if domainID == "" { if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
@@ -120,13 +105,7 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) { func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domainID := mux.Vars(r)["domainId"] domainID := mux.Vars(r)["domainId"]
if domainID == "" { if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)

View File

@@ -11,11 +11,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
type store interface { type store interface {
@@ -37,32 +33,22 @@ type proxyManager interface {
} }
type Manager struct { type Manager struct {
store store store store
validator domain.Validator validator domain.Validator
proxyManager proxyManager proxyManager proxyManager
permissionsManager permissions.Manager accountManager account.Manager
accountManager account.Manager
} }
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager { func NewManager(store store, proxyMgr proxyManager, accountManager account.Manager) Manager {
return Manager{ return Manager{
store: store, store: store,
proxyManager: proxyMgr, proxyManager: proxyMgr,
validator: domain.Validator{Resolver: net.DefaultResolver}, validator: domain.Validator{Resolver: net.DefaultResolver},
permissionsManager: permissionsManager, accountManager: accountManager,
accountManager: accountManager,
} }
} }
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) { func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
domains, err := m.store.ListCustomDomains(ctx, accountID) domains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("list custom domains: %w", err) return nil, fmt.Errorf("list custom domains: %w", err)
@@ -118,14 +104,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
} }
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) { func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters // Verify the target cluster is in the available clusters
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil { if err != nil {
@@ -159,14 +137,6 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
} }
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error { func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
d, err := m.store.GetCustomDomain(ctx, accountID, domainID) d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
if err != nil { if err != nil {
return fmt.Errorf("get domain from store: %w", err) return fmt.Errorf("get domain from store: %w", err)
@@ -183,21 +153,6 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
} }
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) { func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
return
}
if !ok {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
}
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"accountID": accountID, "accountID": accountID,
"domainID": domainID, "domainID": domainID,

View File

@@ -23,7 +23,7 @@ type Proxy struct {
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time ConnectedAt *time.Time
DisconnectedAt *time.Time DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Capabilities Capabilities `gorm:"embedded"` Capabilities Capabilities `gorm:"embedded"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time

View File

@@ -6,12 +6,14 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -30,25 +32,19 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
} }
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter() domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
domainmanager.RegisterEndpoints(domainRouter, domainManager) domainmanager.RegisterEndpoints(domainRouter, domainManager, permissionsManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager) accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager)
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/clusters", permissionsManager.WithPermission(modules.Services, operations.Read, h.getClusters)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS") router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Update, h.updateService)).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteService)).Methods("DELETE", "OPTIONS")
} }
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId) allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@@ -63,13 +59,7 @@ func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiServices) util.WriteJSONObject(r.Context(), w, apiServices)
} }
func (h *handler) createService(w http.ResponseWriter, r *http.Request) { func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.ServiceRequest var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -77,12 +67,13 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
} }
service := new(rpservice.Service) service := new(rpservice.Service)
var err error
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil { if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
if err = service.Validate(); err != nil { if err := service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -96,13 +87,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
} }
func (h *handler) getService(w http.ResponseWriter, r *http.Request) { func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"] serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" { if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -118,13 +103,7 @@ func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
} }
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) { func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"] serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" { if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -139,12 +118,13 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
service := new(rpservice.Service) service := new(rpservice.Service)
service.ID = serviceID service.ID = serviceID
var err error
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil { if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
if err = service.Validate(); err != nil { if err := service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -158,13 +138,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
} }
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"] serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" { if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -179,13 +153,7 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { func (h *handler) getClusters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId) clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)

View File

@@ -15,7 +15,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
@@ -86,18 +85,17 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
accountMgr := &mock_server.MockAccountManager{ accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}, },
} }
mgr := &Manager{ mgr := &Manager{
store: testStore, store: testStore,
accountManager: accountMgr, accountManager: accountMgr,
permissionsManager: permissions.NewManager(testStore), proxyController: mockCtrl,
proxyController: mockCtrl, capabilities: mockCaps,
capabilities: mockCaps, clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
} }
mgr.exposeReaper = &exposeReaper{manager: mgr} mgr.exposeReaper = &exposeReaper{manager: mgr}

View File

@@ -21,9 +21,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -82,24 +79,22 @@ type CapabilityProvider interface {
} }
type Manager struct { type Manager struct {
store store.Store store store.Store
accountManager account.Manager accountManager account.Manager
permissionsManager permissions.Manager proxyController proxy.Controller
proxyController proxy.Controller capabilities CapabilityProvider
capabilities CapabilityProvider clusterDeriver ClusterDeriver
clusterDeriver ClusterDeriver exposeReaper *exposeReaper
exposeReaper *exposeReaper
} }
// NewManager creates a new service manager. // NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager { func NewManager(store store.Store, accountManager account.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
mgr := &Manager{ mgr := &Manager{
store: store, store: store,
accountManager: accountManager, accountManager: accountManager,
permissionsManager: permissionsManager, proxyController: proxyController,
proxyController: proxyController, capabilities: capabilities,
capabilities: capabilities, clusterDeriver: clusterDeriver,
clusterDeriver: clusterDeriver,
} }
mgr.exposeReaper = &exposeReaper{manager: mgr} mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr return mgr
@@ -112,26 +107,10 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
// GetActiveClusters returns all active proxy clusters with their connected proxy count. // GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx) return m.store.GetActiveProxyClusters(ctx)
} }
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err) return nil, fmt.Errorf("failed to get services: %w", err)
@@ -185,14 +164,6 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
} }
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) { func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err) return nil, fmt.Errorf("failed to get service: %w", err)
@@ -206,14 +177,6 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
} }
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) { func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil { if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err return nil, err
} }
@@ -224,7 +187,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta()) m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, s) err := m.replaceHostByLookup(ctx, accountID, s)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
} }
@@ -491,14 +454,6 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
} }
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) { func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := service.Auth.HashSecrets(); err != nil { if err := service.Auth.HashSecrets(); err != nil {
return nil, fmt.Errorf("hash secrets: %w", err) return nil, fmt.Errorf("hash secrets: %w", err)
} }
@@ -785,16 +740,8 @@ func validateResourceTargetType(target *service.Target, resource *resourcetypes.
} }
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var s *service.Service var s *service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error var err error
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil { if err != nil {
@@ -825,16 +772,8 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
} }
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error { func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var services []*service.Service var services []*service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error var err error
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID) services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
@@ -1119,7 +1058,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
} }
groupIDs := make([]string, 0, len(groupNames)) groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames { for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator) g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err) return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
} }

View File

@@ -23,9 +23,6 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -700,12 +697,10 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID) err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
require.NoError(t, err) require.NoError(t, err)
permsMgr := permissions.NewManager(testStore)
accountMgr := &mock_server.MockAccountManager{ accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}, },
} }
@@ -718,10 +713,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
require.NoError(t, err) require.NoError(t, err)
mgr := &Manager{ mgr := &Manager{
store: testStore, store: testStore,
accountManager: accountMgr, accountManager: accountMgr,
permissionsManager: permsMgr, proxyController: proxyController,
proxyController: proxyController,
clusterDeriver: &testClusterDeriver{ clusterDeriver: &testClusterDeriver{
domains: []string{"test.netbird.io"}, domains: []string{"test.netbird.io"},
}, },
@@ -1130,7 +1124,6 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
mockPerms := permissions.NewMockManager(ctrl)
mockAcct := account.NewMockManager(ctrl) mockAcct := account.NewMockManager(ctrl)
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
@@ -1141,10 +1134,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
mgr := &Manager{ mgr := &Manager{
store: sqlStore, store: sqlStore,
permissionsManager: mockPerms, accountManager: mockAcct,
accountManager: mockAcct, proxyController: proxyController,
proxyController: proxyController,
} }
service := &rpservice.Service{ service := &rpservice.Service{
@@ -1167,9 +1159,6 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion") require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
mockPerms.EXPECT().
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
Return(true, nil)
mockAcct.EXPECT(). mockAcct.EXPECT().
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any()) StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
mockAcct.EXPECT(). mockAcct.EXPECT().

View File

@@ -6,8 +6,11 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -17,25 +20,19 @@ type handler struct {
manager zones.Manager manager zones.Manager
} }
func RegisterEndpoints(router *mux.Router, manager zones.Manager) { func RegisterEndpoints(router *mux.Router, manager zones.Manager, permissionsManager permissions.Manager) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS") router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllZones)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS") router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createZone)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getZone)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateZone)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteZone)).Methods("DELETE", "OPTIONS")
} }
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId) allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@@ -50,13 +47,7 @@ func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiZones) util.WriteJSONObject(r.Context(), w, apiZones)
} }
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) { func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiDnsZonesJSONRequestBody var req api.PostApiDnsZonesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -66,7 +57,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
zone := new(zones.Zone) zone := new(zones.Zone)
zone.FromAPIRequest(&req) zone.FromAPIRequest(&req)
if err = zone.Validate(); err != nil { if err := zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -80,13 +71,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
} }
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) { func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -102,13 +87,7 @@ func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
} }
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) { func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -116,7 +95,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
} }
var req api.PutApiDnsZonesZoneIdJSONRequestBody var req api.PutApiDnsZonesZoneIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
@@ -125,7 +104,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
zone.FromAPIRequest(&req) zone.FromAPIRequest(&req)
zone.ID = zoneID zone.ID = zoneID
if err = zone.Validate(); err != nil { if err := zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -139,20 +118,14 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
} }
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return return
} }
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil { if err := h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@@ -7,62 +7,34 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
accountManager account.Manager accountManager account.Manager
permissionsManager permissions.Manager dnsDomain string
dnsDomain string
} }
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager { func NewManager(store store.Store, accountManager account.Manager, dnsDomain string) zones.Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
accountManager: accountManager, accountManager: accountManager,
permissionsManager: permissionsManager, dnsDomain: dnsDomain,
dnsDomain: dnsDomain,
} }
} }
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) { func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
} }
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) { func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID) return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
} }
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) { func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create) var err error
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil { if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
return nil, err return nil, err
} }
@@ -102,14 +74,6 @@ func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string,
} }
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) { func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID) zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get zone: %w", err) return nil, fmt.Errorf("failed to get zone: %w", err)
@@ -150,14 +114,6 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
} }
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error { func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get zone: %w", err) return fmt.Errorf("failed to get zone: %w", err)

View File

@@ -13,9 +13,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -29,7 +26,7 @@ const (
testDNSDomain = "netbird.selfhosted" testDNSDomain = "netbird.selfhosted"
) )
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) { func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *gomock.Controller, func()) {
t.Helper() t.Helper()
ctx := context.Background() ctx := context.Background()
@@ -49,23 +46,17 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccoun
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{} mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{ manager := NewManager(testStore, mockAccountManager, testDNSDomain).(*managerImpl)
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
dnsDomain: testDNSDomain,
}
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup return manager, testStore, mockAccountManager, ctrl, cleanup
} }
func TestManagerImpl_GetAllZones(t *testing.T) { func TestManagerImpl_GetAllZones(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -77,10 +68,6 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
err = testStore.CreateZone(ctx, zone2) err = testStore.CreateZone(ctx, zone2)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID) result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, result, 2) assert.Len(t, result, 2)
@@ -88,43 +75,13 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
assert.Equal(t, zone2.ID, result[1].ID) assert.Equal(t, zone2.ID, result[1].ID)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
})
} }
func TestManagerImpl_GetZone(t *testing.T) { func TestManagerImpl_GetZone(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -132,10 +89,6 @@ func TestManagerImpl_GetZone(t *testing.T) {
err := testStore.CreateZone(ctx, zone) err := testStore.CreateZone(ctx, zone)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID) result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, zone.ID, result.ID) assert.Equal(t, zone.ID, result.ID)
@@ -143,29 +96,13 @@ func TestManagerImpl_GetZone(t *testing.T) {
assert.Equal(t, zone.Domain, result.Domain) assert.Equal(t, zone.Domain, result.Domain)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
} }
func TestManagerImpl_CreateZone(t *testing.T) { func TestManagerImpl_CreateZone(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -177,10 +114,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID) assert.Equal(t, testAccountID, accountID)
@@ -199,31 +132,8 @@ func TestManagerImpl_CreateZone(t *testing.T) {
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups) assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("invalid group", func(t *testing.T) { t.Run("invalid group", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -233,17 +143,13 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{"invalid-group"}, DistributionGroups: []string{"invalid-group"},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
}) })
t.Run("duplicate domain", func(t *testing.T) { t.Run("duplicate domain", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -259,10 +165,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -273,7 +175,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
}) })
t.Run("peer DNS domain conflict", func(t *testing.T) { t.Run("peer DNS domain conflict", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -291,10 +193,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -305,7 +203,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
}) })
t.Run("default DNS domain conflict", func(t *testing.T) { t.Run("default DNS domain conflict", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -317,10 +215,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -335,7 +229,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -352,10 +246,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true storeEventCalled = true
@@ -375,7 +265,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
}) })
t.Run("domain change not allowed", func(t *testing.T) { t.Run("domain change not allowed", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -392,10 +282,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -405,31 +291,8 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
assert.Equal(t, status.InvalidArgument, s.Type()) assert.Equal(t, status.InvalidArgument, s.Type())
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: testZoneID,
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) { t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -439,10 +302,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
Domain: "example.com", Domain: "example.com",
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -453,7 +312,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success with records", func(t *testing.T) { t.Run("success with records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -469,10 +328,6 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
err = testStore.CreateDNSRecord(ctx, record2) err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCallCount := 0 storeEventCallCount := 0
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCallCount++ storeEventCallCount++
@@ -493,7 +348,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
}) })
t.Run("success without records", func(t *testing.T) { t.Run("success without records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -501,10 +356,6 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
err := testStore.CreateZone(ctx, zone) err := testStore.CreateZone(ctx, zone)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true storeEventCalled = true
@@ -522,31 +373,11 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
require.Error(t, err) require.Error(t, err)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) { t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone") err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
require.Error(t, err) require.Error(t, err)
}) })

View File

@@ -6,8 +6,11 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -17,25 +20,19 @@ type handler struct {
manager records.Manager manager records.Manager
} }
func RegisterEndpoints(router *mux.Router, manager records.Manager) { func RegisterEndpoints(router *mux.Router, manager records.Manager, permissionsManager permissions.Manager) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllRecords)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createRecord)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getRecord)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateRecord)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteRecord)).Methods("DELETE", "OPTIONS")
} }
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -56,13 +53,7 @@ func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiRecords) util.WriteJSONObject(r.Context(), w, apiRecords)
} }
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) { func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -78,7 +69,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
record := new(records.Record) record := new(records.Record)
record.FromAPIRequest(&req) record.FromAPIRequest(&req)
if err = record.Validate(); err != nil { if err := record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -92,13 +83,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
} }
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) { func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -120,13 +105,7 @@ func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
} }
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) { func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -140,7 +119,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
} }
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
@@ -149,7 +128,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
record.FromAPIRequest(&req) record.FromAPIRequest(&req)
record.ID = recordID record.ID = recordID
if err = record.Validate(); err != nil { if err := record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -163,13 +142,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
} }
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -182,7 +155,7 @@ func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
return return
} }
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil { if err := h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@@ -9,64 +9,36 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
accountManager account.Manager accountManager account.Manager
permissionsManager permissions.Manager
} }
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager { func NewManager(store store.Store, accountManager account.Manager) records.Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
accountManager: accountManager, accountManager: accountManager,
permissionsManager: permissionsManager,
} }
} }
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) { func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID) return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
} }
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) { func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID) return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
} }
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) { func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone var zone *zones.Zone
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL) record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get zone: %w", err) return fmt.Errorf("failed to get zone: %w", err)
@@ -101,18 +73,11 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
} }
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) { func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone var zone *zones.Zone
var record *records.Record var record *records.Record
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get zone: %w", err) return fmt.Errorf("failed to get zone: %w", err)
@@ -160,18 +125,11 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
} }
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error { func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var record *records.Record var record *records.Record
var zone *zones.Zone var zone *zones.Zone
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get zone: %w", err) return fmt.Errorf("failed to get zone: %w", err)

View File

@@ -12,12 +12,8 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -27,7 +23,7 @@ const (
testGroupID = "test-group-id" testGroupID = "test-group-id"
) )
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) { func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *gomock.Controller, func()) {
t.Helper() t.Helper()
ctx := context.Background() ctx := context.Background()
@@ -51,22 +47,17 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_serv
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{} mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{ manager := NewManager(testStore, mockAccountManager).(*managerImpl)
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
}
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup return manager, testStore, zone, mockAccountManager, ctrl, cleanup
} }
func TestManagerImpl_GetAllRecords(t *testing.T) { func TestManagerImpl_GetAllRecords(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -78,10 +69,6 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
err = testStore.CreateDNSRecord(ctx, record2) err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID) result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, result, 2) assert.Len(t, result, 2)
@@ -89,43 +76,13 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
assert.Equal(t, record2.ID, result[1].ID) assert.Equal(t, record2.ID, result[1].ID)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
})
} }
func TestManagerImpl_GetRecord(t *testing.T) { func TestManagerImpl_GetRecord(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -133,10 +90,6 @@ func TestManagerImpl_GetRecord(t *testing.T) {
err := testStore.CreateDNSRecord(ctx, record) err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID) result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, record.ID, result.ID) assert.Equal(t, record.ID, result.ID)
@@ -146,29 +99,13 @@ func TestManagerImpl_GetRecord(t *testing.T) {
assert.Equal(t, record.TTL, result.TTL) assert.Equal(t, record.TTL, result.TTL)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
} }
func TestManagerImpl_CreateRecord(t *testing.T) { func TestManagerImpl_CreateRecord(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success - A record", func(t *testing.T) { t.Run("success - A record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -179,10 +116,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID) assert.Equal(t, testAccountID, accountID)
@@ -202,7 +135,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
}) })
t.Run("success - AAAA record", func(t *testing.T) { t.Run("success - AAAA record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -213,10 +146,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 600, TTL: 600,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID) assert.Equal(t, testAccountID, accountID)
@@ -231,7 +160,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
}) })
t.Run("success - CNAME record", func(t *testing.T) { t.Run("success - CNAME record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -242,10 +171,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID) assert.Equal(t, testAccountID, accountID)
@@ -259,32 +184,8 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
assert.Equal(t, inputRecord.Content, result.Content) assert.Equal(t, inputRecord.Content, result.Content)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record name not in zone", func(t *testing.T) { t.Run("record name not in zone", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -295,10 +196,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -306,7 +203,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
}) })
t.Run("duplicate record", func(t *testing.T) { t.Run("duplicate record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -321,10 +218,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -332,7 +225,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
}) })
t.Run("CNAME conflict with existing A record", func(t *testing.T) { t.Run("CNAME conflict with existing A record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -347,10 +240,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -362,7 +251,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -378,10 +267,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600, // Changed TTL TTL: 600, // Changed TTL
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true storeEventCalled = true
@@ -400,7 +285,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
}) })
t.Run("update only TTL - no validation", func(t *testing.T) { t.Run("update only TTL - no validation", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -416,10 +301,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600, // Only TTL changed TTL: 600, // Only TTL changed
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
// Event should be stored // Event should be stored
} }
@@ -430,33 +311,8 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
assert.Equal(t, 600, result.TTL) assert.Equal(t, 600, result.TTL)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: testRecordID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) { t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -468,17 +324,13 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600, TTL: 600,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
}) })
t.Run("update creates duplicate", func(t *testing.T) { t.Run("update creates duplicate", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -498,10 +350,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -513,7 +361,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -521,10 +369,6 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
err := testStore.CreateDNSRecord(ctx, record) err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true storeEventCalled = true
@@ -542,31 +386,11 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
require.Error(t, err) require.Error(t, err)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) { t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) manager, _, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record") err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
require.Error(t, err) require.Error(t, err)
}) })

View File

@@ -233,7 +233,7 @@ func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
func (s *BaseServer) AccessLogsManager() accesslogs.Manager { func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager { return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager()) accessLogManager := accesslogsmanager.NewManager(s.Store(), s.GeoLocationManager())
accessLogManager.StartPeriodicCleanup( accessLogManager.StartPeriodicCleanup(
context.Background(), context.Background(),
s.Config.ReverseProxy.AccessLogRetentionDays, s.Config.ReverseProxy.AccessLogRetentionDays,

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
@@ -27,7 +28,6 @@ import (
"github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
) )
@@ -82,13 +82,13 @@ func (s *BaseServer) SettingsManager() settings.Manager {
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
} }
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig) return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, idpConfig)
}) })
} }
func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) PeersManager() peers.Manager {
return Create(s, func() peers.Manager { return Create(s, func() peers.Manager {
manager := peers.NewManager(s.Store(), s.PermissionsManager()) manager := peers.NewManager(s.Store())
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
manager.SetNetworkMapController(s.NetworkMapController()) manager.SetNetworkMapController(s.NetworkMapController())
manager.SetIntegratedPeerValidator(s.IntegratedValidator()) manager.SetIntegratedPeerValidator(s.IntegratedValidator())
@@ -161,43 +161,43 @@ func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
func (s *BaseServer) GroupsManager() groups.Manager { func (s *BaseServer) GroupsManager() groups.Manager {
return Create(s, func() groups.Manager { return Create(s, func() groups.Manager {
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager()) return groups.NewManager(s.Store(), s.AccountManager())
}) })
} }
func (s *BaseServer) ResourcesManager() resources.Manager { func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager { return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager()) return resources.NewManager(s.Store(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
}) })
} }
func (s *BaseServer) RoutesManager() routers.Manager { func (s *BaseServer) RoutesManager() routers.Manager {
return Create(s, func() routers.Manager { return Create(s, func() routers.Manager {
return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager()) return routers.NewManager(s.Store(), s.AccountManager())
}) })
} }
func (s *BaseServer) NetworksManager() networks.Manager { func (s *BaseServer) NetworksManager() networks.Manager {
return Create(s, func() networks.Manager { return Create(s, func() networks.Manager {
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager()) return networks.NewManager(s.Store(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
}) })
} }
func (s *BaseServer) ZonesManager() zones.Manager { func (s *BaseServer) ZonesManager() zones.Manager {
return Create(s, func() zones.Manager { return Create(s, func() zones.Manager {
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain()) return zonesManager.NewManager(s.Store(), s.AccountManager(), s.DNSDomain())
}) })
} }
func (s *BaseServer) RecordsManager() records.Manager { func (s *BaseServer) RecordsManager() records.Manager {
return Create(s, func() records.Manager { return Create(s, func() records.Manager {
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager()) return recordsManager.NewManager(s.Store(), s.AccountManager())
}) })
} }
func (s *BaseServer) ServiceManager() service.Manager { func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() service.Manager { return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager()) return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
}) })
} }
@@ -213,7 +213,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager { return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager()) m := manager.NewManager(s.Store(), s.ProxyManager(), s.AccountManager())
return &m return &m
}) })
} }

View File

@@ -15,6 +15,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
@@ -39,9 +40,6 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -282,22 +280,14 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// User that performs the update has to belong to the account. // User that performs the update has to belong to the account.
// Returns an updated Settings // Returns an updated Settings
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) { func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
var oldSettings *types.Settings var oldSettings *types.Settings
var updateAccountPeers bool var updateAccountPeers bool
var groupChangesAffectPeers bool var groupChangesAffectPeers bool
var reloadReverseProxy bool var reloadReverseProxy bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var groupsUpdated bool var groupsUpdated bool
var err error
oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID) oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
@@ -725,15 +715,6 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return err return err
} }
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
if err != nil {
return fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users)) userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
@@ -976,6 +957,10 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
return nil, err return nil, err
} }
if user.AccountID != accountID {
return nil, fmt.Errorf("user %s does not belong to account %s", userID, accountID)
}
key := user.IntegrationReference.CacheKey(accountID, userID) key := user.IntegrationReference.CacheKey(accountID, userID)
ud, err := am.externalCacheManager.Get(am.ctx, key) ud, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil { if err != nil {
@@ -1287,41 +1272,16 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
// GetAccountByID returns an account associated with this account ID. // GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccount(ctx, accountID) return am.Store.GetAccount(ctx, accountID)
} }
// GetAccountMeta returns the account metadata associated with this account ID. // GetAccountMeta returns the account metadata associated with this account ID.
func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID) return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
} }
// GetAccountOnboarding retrieves the onboarding information for a specific account. // GetAccountOnboarding retrieves the onboarding information for a specific account.
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
log.Errorf("failed to get account onboarding for account %s: %v", accountID, err) log.Errorf("failed to get account onboarding for account %s: %v", accountID, err)
@@ -1338,15 +1298,6 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
} }
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
return nil, fmt.Errorf("failed to get account onboarding: %w", err) return nil, fmt.Errorf("failed to get account onboarding: %w", err)
@@ -1405,9 +1356,8 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil return accountID, user.Id, nil
} }
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { // Permission checks are now handled by the HTTP middleware via WithPermission wrapper
return "", "", err // User account association is already validated above by GetUserByUserID
}
if !user.IsServiceUser && userAuth.Invited { if !user.IsServiceUser && userAuth.Invited {
err = am.redeemInvite(ctx, accountID, user.Id) err = am.redeemInvite(ctx, accountID, user.Id)
@@ -1849,13 +1799,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
} }
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
} }
@@ -2197,14 +2140,6 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
} }
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error { func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP) updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP)
if err != nil { if err != nil {
return fmt.Errorf("update peer IP transaction: %w", err) return fmt.Errorf("update peer IP transaction: %w", err)

View File

@@ -60,7 +60,7 @@ type Manager interface {
GetUserByID(ctx context.Context, id string) (*types.User, error) GetUserByID(ctx context.Context, id string) (*types.User, error)
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -75,7 +75,7 @@ type Manager interface {
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error

View File

@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
} }
// GetGroupByName mocks base method. // GetGroupByName mocks base method.
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID) ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
ret0, _ := ret[0].(*types.Group) ret0, _ := ret[0].(*types.Group)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetGroupByName indicates an expected call of GetGroupByName. // GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
} }
// GetIdentityProvider mocks base method. // GetIdentityProvider mocks base method.
@@ -946,18 +946,18 @@ func (mr *MockManagerMockRecorder) GetPeerNetwork(ctx, peerID interface{}) *gomo
} }
// GetPeers mocks base method. // GetPeers mocks base method.
func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*peer.Peer, error) { func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*peer.Peer, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter) ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter, all)
ret0, _ := ret[0].([]*peer.Peer) ret0, _ := ret[0].([]*peer.Peer)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetPeers indicates an expected call of GetPeers. // GetPeers indicates an expected call of GetPeers.
func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter, all interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter, all)
} }
// GetPolicy mocks base method. // GetPolicy mocks base method.

View File

@@ -22,9 +22,11 @@ import (
"go.opentelemetry.io/otel/metric/noop" "go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -49,7 +51,6 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -3147,7 +3148,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager) peersManager := peers.NewManager(store)
proxyManager := proxy.NewMockManager(ctrl) proxyManager := proxy.NewMockManager(ctrl)
proxyManager.EXPECT(). proxyManager.EXPECT().
@@ -3164,7 +3165,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
updateManager := update_channel.NewPeersUpdateManager(metrics) updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store) requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -3175,7 +3176,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil)) manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, proxyController, proxyManager, nil))
return manager, updateManager, nil return manager, updateManager, nil
} }

View File

@@ -8,8 +8,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
@@ -22,14 +20,6 @@ const (
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
} }
@@ -39,18 +29,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
} }
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool var updateAccountPeers bool
var eventsToStore []func() var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err return err
} }

View File

@@ -14,11 +14,11 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -28,7 +28,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -79,16 +78,6 @@ func TestGetDNSSettings(t *testing.T) {
if len(dnsSettings.DisabledManagementGroups) != 1 { if len(dnsSettings.DisabledManagementGroups) != 1 {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
} }
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
}
s, ok := status.FromError(err)
if !ok && s.Type() != status.PermissionDenied {
t.Errorf("returned error should be Permission Denied, got err: %s", err)
}
} }
func TestSaveDNSSettings(t *testing.T) { func TestSaveDNSSettings(t *testing.T) {
@@ -223,7 +212,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers // return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager) peersManager := peers.NewManager(store)
ctx := context.Background() ctx := context.Background()
@@ -234,7 +223,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
updateManager := update_channel.NewPeersUpdateManager(metrics) updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store) requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
} }

View File

@@ -9,11 +9,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
func isEnabled() bool { func isEnabled() bool {
@@ -23,14 +20,6 @@ func isEnabled() bool {
// GetEvents returns a list of activity events of an account // GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true) events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -12,8 +12,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
@@ -32,13 +30,24 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups // CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read) // Permission checks are now handled by the HTTP middleware via WithPermission wrapper
// This method is called from authenticated/authorized handlers, so we just validate
// that the user exists and is part of the account
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return err return err
} }
if !allowed { if user == nil {
return status.NewPermissionDeniedError() return status.NewUserNotFoundError(userID)
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsBlocked() {
return status.NewUserBlockedError()
} }
return nil return nil
@@ -61,27 +70,17 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
} }
// CreateGroup object of the peers // CreateGroup object of the peers
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func() var eventsToStore []func()
var updateAccountPeers bool var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err return err
} }
@@ -125,19 +124,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
// UpdateGroup object of the peers // UpdateGroup object of the peers
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func() var eventsToStore []func()
var updateAccountPeers bool var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err return err
} }
@@ -196,33 +187,24 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. // It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that. // This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func() var eventsToStore []func()
var updateAccountPeers bool var updateAccountPeers bool
var globalErr error var globalErr error
groupIDs := make([]string, 0, len(groups)) groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups { for _, newGroup := range groups {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err return err
} }
newGroup.AccountID = accountID newGroup.AccountID = accountID
if err = transaction.CreateGroup(ctx, newGroup); err != nil { if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return err return err
} }
err = transaction.IncrementNetworkSerial(ctx, accountID) if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
if err != nil {
return err return err
} }
@@ -243,6 +225,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
} }
} }
var err error
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil { if err != nil {
return err return err
@@ -264,21 +247,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. // It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that. // This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func() var eventsToStore []func()
var updateAccountPeers bool var updateAccountPeers bool
var globalErr error var globalErr error
groupIDs := make([]string, 0, len(groups)) groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups { for _, newGroup := range groups {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err return err
} }
@@ -311,6 +287,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
} }
} }
var err error
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil { if err != nil {
return err return err
@@ -416,14 +393,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
// If an error occurs while deleting a group, the function skips it and continues deleting other groups. // If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end. // Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var allErrors error var allErrors error
var groupIDsToDelete []string var groupIDsToDelete []string
var deletedGroups []*types.Group var deletedGroups []*types.Group

View File

@@ -26,7 +26,6 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer" peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -764,11 +763,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
// Saving a group linked to network router should update account peers and send peer update // Saving a group linked to network router should update account peers and send peer update
t.Run("saving group linked to network router", func(t *testing.T) { t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store) groupsManager := groups.NewManager(manager.Store, manager)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager) resourcesManager := resources.NewManager(manager.Store, groupsManager, manager, manager.serviceManager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager) routersManager := routers.NewManager(manager.Store, manager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager) networksManager := networks.NewManager(manager.Store, resourcesManager, routersManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{ network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{
ID: "network_test", ID: "network_test",

View File

@@ -6,9 +6,6 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
@@ -25,31 +22,21 @@ type Manager interface {
} }
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
permissionsManager permissions.Manager accountManager account.Manager
accountManager account.Manager
} }
type mockManager struct { type mockManager struct {
} }
func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager account.Manager) Manager { func NewManager(store store.Store, accountManager account.Manager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager, accountManager: accountManager,
accountManager: accountManager,
} }
} }
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
if err != nil {
return nil, err
}
if !ok {
return nil, err
}
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
@@ -73,14 +60,6 @@ func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID str
} }
func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error { func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return err
}
if !ok {
return err
}
event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, userID, groupID, resource) event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, userID, groupID, resource)
if err != nil { if err != nil {
return fmt.Errorf("error adding resource to group: %w", err) return fmt.Errorf("error adding resource to group: %w", err)

View File

@@ -10,6 +10,7 @@ import (
"github.com/rs/cors" "github.com/rs/cors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -31,10 +32,8 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy" "github.com/netbirdio/netbird/management/server/http/handlers/proxy"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers" nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/auth"
@@ -124,25 +123,25 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to create instance manager: %w", err) return nil, fmt.Errorf("failed to create instance manager: %w", err)
} }
accounts.AddEndpoints(accountManager, settingsManager, router) accounts.AddEndpoints(accountManager, settingsManager, router, permissionsManager)
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager) peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
users.AddEndpoints(accountManager, router) users.AddEndpoints(accountManager, router, permissionsManager)
users.AddInvitesEndpoints(accountManager, router) users.AddInvitesEndpoints(accountManager, router, permissionsManager)
users.AddPublicInvitesEndpoints(accountManager, router) users.AddPublicInvitesEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router) setup_keys.AddEndpoints(accountManager, router, permissionsManager)
policies.AddEndpoints(accountManager, LocationManager, router) policies.AddEndpoints(accountManager, LocationManager, router, permissionsManager)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router) policies.AddPostureCheckEndpoints(accountManager, LocationManager, router, permissionsManager)
policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router) policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router)
groups.AddEndpoints(accountManager, router) groups.AddEndpoints(accountManager, router, permissionsManager)
routes.AddEndpoints(accountManager, router) routes.AddEndpoints(accountManager, router, permissionsManager)
dns.AddEndpoints(accountManager, router) dns.AddEndpoints(accountManager, router, permissionsManager)
events.AddEndpoints(accountManager, router) events.AddEndpoints(accountManager, router, permissionsManager)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router) networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, permissionsManager, router)
zonesManager.RegisterEndpoints(router, zManager) zonesManager.RegisterEndpoints(router, zManager, permissionsManager)
recordsManager.RegisterEndpoints(router, rManager) recordsManager.RegisterEndpoints(router, rManager, permissionsManager)
idp.AddEndpoints(accountManager, router) idp.AddEndpoints(accountManager, router, permissionsManager)
instance.AddEndpoints(instanceManager, router) instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router, permissionsManager)
if serviceManager != nil && reverseProxyDomainManager != nil { if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
} }

View File

@@ -12,10 +12,13 @@ import (
goversion "github.com/hashicorp/go-version" goversion "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -40,11 +43,11 @@ type handler struct {
settingsManager settings.Manager settingsManager settings.Manager
} }
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) { func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router, permissionsManager permissions.Manager) {
accountsHandler := newHandler(accountManager, settingsManager) accountsHandler := newHandler(accountManager, settingsManager)
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Update, accountsHandler.updateAccount)).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Delete, accountsHandler.deleteAccount)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") router.HandleFunc("/accounts", permissionsManager.WithPermission(modules.Accounts, operations.Read, accountsHandler.getAllAccounts)).Methods("GET", "OPTIONS")
} }
// newHandler creates a new handler HTTP handler // newHandler creates a new handler HTTP handler
@@ -99,7 +102,7 @@ func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID st
} }
func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error { func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error {
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "") peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "", true)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "get peer count: %v", err) return status.Errorf(status.Internal, "get peer count: %v", err)
} }
@@ -136,34 +139,26 @@ func calculateRequiredAddresses(peerCount int) int64 {
} }
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. // getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) meta, err := h.accountManager.GetAccountMeta(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountID, userID := userAuth.AccountId, userAuth.UserId settings, err := h.settingsManager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID) onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID) resp := toAccountResponse(userAuth.AccountId, settings, meta, onboarding)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(accountID, settings, meta, onboarding)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
} }
@@ -233,24 +228,15 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
} }
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) // updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) accountID := mux.Vars(r)["accountId"]
if err != nil { if accountID != userAuth.AccountId {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w)
return
}
_, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
accountID := vars["accountId"]
if len(accountID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
return return
} }
var req api.PutApiAccountsAccountIdJSONRequestBody var req api.PutApiAccountsAccountIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -267,7 +253,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
return return
} }
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil { if err := h.validateNetworkRange(r.Context(), accountID, userAuth.UserId, prefix); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -282,19 +268,19 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
} }
} }
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding) updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userAuth.UserId, onboarding)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userAuth.UserId, settings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID) meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -306,21 +292,14 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
} }
// deleteAccount is a HTTP DELETE handler to delete an account // deleteAccount is a HTTP DELETE handler to delete an account
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) accountID := mux.Vars(r)["accountId"]
if err != nil { if accountID != userAuth.AccountId {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w)
return return
} }
vars := mux.Vars(r) err := h.accountManager.DeleteAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
targetAccountID := vars["accountId"]
if len(targetAccountID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
return
}
err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -290,8 +291,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET") router.HandleFunc("/api/accounts", permissions.WrapHandler(handler.getAllAccounts)).Methods("GET")
router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT") router.HandleFunc("/api/accounts/{accountId}", permissions.WrapHandler(handler.updateAccount)).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -5,11 +5,13 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
@@ -19,15 +21,15 @@ type dnsSettingsHandler struct {
accountManager account.Manager accountManager account.Manager
} }
func AddEndpoints(accountManager account.Manager, router *mux.Router) { func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
addDNSSettingEndpoint(accountManager, router) addDNSSettingEndpoint(accountManager, router, permissionsManager)
addDNSNameserversEndpoint(accountManager, router) addDNSNameserversEndpoint(accountManager, router, permissionsManager)
} }
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router) { func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
dnsSettingsHandler := newDNSSettingsHandler(accountManager) dnsSettingsHandler := newDNSSettingsHandler(accountManager)
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS") router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Read, dnsSettingsHandler.getDNSSettings)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS") router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Update, dnsSettingsHandler.updateDNSSettings)).Methods("PUT", "OPTIONS")
} }
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler // newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
@@ -36,17 +38,8 @@ func newDNSSettingsHandler(accountManager account.Manager) *dnsSettingsHandler {
} }
// getDNSSettings returns the DNS settings for the account // getDNSSettings returns the DNS settings for the account
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -60,17 +53,9 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque
} }
// updateDNSSettings handles update to DNS settings of an account // updateDNSSettings handles update to DNS settings of an account
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PutApiDnsSettingsJSONRequestBody var req api.PutApiDnsSettingsJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -80,7 +65,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups, DisabledManagementGroups: req.DisabledManagementGroups,
} }
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings) err = h.accountManager.SaveDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId, updateDNSSettings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -17,6 +17,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
@@ -115,8 +116,8 @@ func TestDNSSettingsHandlers(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET") router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.getDNSSettings)).Methods("GET")
router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT") router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.updateDNSSettings)).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -6,11 +6,13 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +23,13 @@ type nameserversHandler struct {
accountManager account.Manager accountManager account.Manager
} }
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router) { func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
nameserversHandler := newNameserversHandler(accountManager) nameserversHandler := newNameserversHandler(accountManager)
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS") router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getAllNameservers)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS") router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Create, nameserversHandler.createNameserverGroup)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS") router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Update, nameserversHandler.updateNameserverGroup)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS") router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getNameserverGroup)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS") router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Delete, nameserversHandler.deleteNameserverGroup)).Methods("DELETE", "OPTIONS")
} }
// newNameserversHandler returns a new instance of nameserversHandler handler // newNameserversHandler returns a new instance of nameserversHandler handler
@@ -36,17 +38,8 @@ func newNameserversHandler(accountManager account.Manager) *nameserversHandler {
} }
// getAllNameservers returns the list of nameserver groups for the account // getAllNameservers returns the list of nameserver groups for the account
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -61,17 +54,9 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re
} }
// createNameserverGroup handles nameserver group creation request // createNameserverGroup handles nameserver group creation request
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PostApiDnsNameserversJSONRequestBody var req api.PostApiDnsNameserversJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -83,7 +68,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
return return
} }
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled) nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), userAuth.AccountId, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userAuth.UserId, req.SearchDomainsEnabled)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -95,15 +80,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
} }
// updateNameserverGroup handles update to a nameserver group identified by a given ID // updateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroupID := mux.Vars(r)["nsgroupId"] nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
@@ -111,7 +88,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
} }
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -135,7 +112,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled, SearchDomainsEnabled: req.SearchDomainsEnabled,
} }
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup) err = h.accountManager.SaveNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, updatedNSGroup)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -147,22 +124,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
} }
// deleteNameserverGroup handles nameserver group deletion request // deleteNameserverGroup handles nameserver group deletion request
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroupID := mux.Vars(r)["nsgroupId"] nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID) err := h.accountManager.DeleteNameServerGroup(r.Context(), userAuth.AccountId, nsGroupID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -172,22 +141,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt
} }
// getNameserverGroup handles a nameserver group Get request identified by ID // getNameserverGroup handles a nameserver group Get request identified by ID
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroupID := mux.Vars(r)["nsgroupId"] nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID) nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, nsGroupID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -18,6 +18,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
@@ -201,10 +202,10 @@ func TestNameserversHandlers(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET") router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.getNameserverGroup)).Methods("GET")
router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST") router.HandleFunc("/api/dns/nameservers", permissions.WrapHandler(p.createNameserverGroup)).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE") router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.deleteNameserverGroup)).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT") router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.updateNameserverGroup)).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -5,11 +5,13 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
@@ -19,10 +21,10 @@ type handler struct {
accountManager account.Manager accountManager account.Manager
} }
func AddEndpoints(accountManager account.Manager, router *mux.Router) { func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
eventsHandler := newHandler(accountManager) eventsHandler := newHandler(accountManager)
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") router.HandleFunc("/events", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
router.HandleFunc("/events/audit", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") router.HandleFunc("/events/audit", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
} }
// newHandler creates a new events handler // newHandler creates a new events handler
@@ -31,17 +33,8 @@ func newHandler(accountManager account.Manager) *handler {
} }
// getAllEvents list of the given account // getAllEvents list of the given account
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) accountEvents, err := h.accountManager.GetEvents(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -13,6 +13,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
@@ -196,7 +197,7 @@ func TestEvents_GetEvents(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET") router.HandleFunc("/api/events/", permissions.WrapHandler(handler.getAllEvents)).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -7,11 +7,13 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -19,46 +21,45 @@ import (
// handler is a handler that returns groups of the account // handler is a handler that returns groups of the account
type handler struct { type handler struct {
accountManager account.Manager accountManager account.Manager
permissionsManager permissions.Manager
} }
func AddEndpoints(accountManager account.Manager, router *mux.Router) { func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
groupsHandler := newHandler(accountManager) groupsHandler := newHandler(accountManager, permissionsManager)
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS") router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS") router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS") router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS") router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getGroup)).Methods("GET", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS") router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Delete, groupsHandler.deleteGroup)).Methods("DELETE", "OPTIONS")
} }
// newHandler creates a new groups handler // newHandler creates a new groups handler
func newHandler(accountManager account.Manager) *handler { func newHandler(accountManager account.Manager, permissionsManager permissions.Manager) *handler {
return &handler{ return &handler{
accountManager: accountManager, accountManager: accountManager,
permissionsManager: permissionsManager,
} }
} }
// getAllGroups list for the account func (h *handler) canReadPeers(r *http.Request, userAuth *auth.UserAuth) bool {
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) return err == nil && allowed
if err != nil { }
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
// getAllGroups list for the account
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
// Check if filtering by name // Check if filtering by name
groupName := r.URL.Query().Get("name") groupName := r.URL.Query().Get("name")
if groupName != "" { if groupName != "" {
// Get single group by name // Get single group by name
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID) group, err := h.accountManager.GetGroupByName(r.Context(), groupName, userAuth.AccountId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -71,13 +72,13 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
} }
// Get all groups // Get all groups
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) groups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -92,15 +93,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
} }
// updateGroup handles update to a group identified by a given ID // updateGroup handles update to a group identified by a given ID
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
groupID, ok := vars["groupId"] groupID, ok := vars["groupId"]
if !ok { if !ok {
@@ -112,13 +105,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) existingGroup, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID) allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", userAuth.AccountId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -166,13 +159,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
IntegrationReference: existingGroup.IntegrationReference, IntegrationReference: existingGroup.IntegrationReference,
} }
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil { if err := h.accountManager.UpdateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err) log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, userAuth.AccountId, err)
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -182,17 +175,9 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
} }
// createGroup handles group creation request // createGroup handles group creation request
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PostApiGroupsJSONRequestBody var req api.PostApiGroupsJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -226,13 +211,13 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
Issued: types.GroupIssuedAPI, Issued: types.GroupIssuedAPI,
} }
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group) err = h.accountManager.CreateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -242,22 +227,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
} }
// deleteGroup handles group deletion request // deleteGroup handles group deletion request
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
groupID := mux.Vars(r)["groupId"] groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 { if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return return
} }
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID) err := h.accountManager.DeleteGroup(r.Context(), userAuth.AccountId, userAuth.UserId, groupID)
if err != nil { if err != nil {
wrappedErr, ok := err.(interface{ Unwrap() []error }) wrappedErr, ok := err.(interface{ Unwrap() []error })
if ok && len(wrappedErr.Unwrap()) > 0 { if ok && len(wrappedErr.Unwrap()) > 0 {
@@ -273,34 +250,26 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
} }
// getGroup returns a group // getGroup returns a group
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
groupID := mux.Vars(r)["groupId"] groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 { if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return return
} }
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) group, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group)) util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
} }
func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group { func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group {

View File

@@ -13,10 +13,14 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
@@ -33,8 +37,18 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
} }
func initGroupTestData(initGroups ...*types.Group) *handler { func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler {
t.Helper()
ctrl := gomock.NewController(t)
permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock.EXPECT().
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
Return(true, nil).
AnyTimes()
return &handler{ return &handler{
permissionsManager: permissionsManagerMock,
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error { SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error {
if !strings.HasPrefix(group.ID, "id-") { if !strings.HasPrefix(group.ID, "id-") {
@@ -71,14 +85,14 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return groups, nil return groups, nil
}, },
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) { GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
if groupName == "All" { if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
} }
return nil, status.Errorf(status.NotFound, "unknown group name") return nil, status.Errorf(status.NotFound, "unknown group name")
}, },
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) {
return maps.Values(TestPeers), nil return maps.Values(TestPeers), nil
}, },
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
@@ -128,7 +142,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group", Name: "Group",
} }
p := initGroupTestData(group) p := initGroupTestData(t, group)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -141,7 +155,7 @@ func TestGetGroup(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET") router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.getGroup)).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -254,7 +268,7 @@ func TestWriteGroup(t *testing.T) {
}, },
} }
p := initGroupTestData() p := initGroupTestData(t)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -267,8 +281,8 @@ func TestWriteGroup(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/groups", p.createGroup).Methods("POST") router.HandleFunc("/api/groups", permissions.WrapHandler(p.createGroup)).Methods("POST")
router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT") router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.updateGroup)).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -332,7 +346,7 @@ func TestGetAllGroups(t *testing.T) {
}, },
} }
p := initGroupTestData() p := initGroupTestData(t)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -345,7 +359,7 @@ func TestGetAllGroups(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET") router.HandleFunc("/api/groups", permissions.WrapHandler(p.getAllGroups)).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -414,7 +428,7 @@ func TestDeleteGroup(t *testing.T) {
}, },
} }
p := initGroupTestData() p := initGroupTestData(t)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -426,7 +440,7 @@ func TestDeleteGroup(t *testing.T) {
AccountId: "test_id", AccountId: "test_id",
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE") router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.deleteGroup)).Methods("DELETE")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -6,9 +6,12 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -20,13 +23,13 @@ type handler struct {
} }
// AddEndpoints registers identity provider endpoints // AddEndpoints registers identity provider endpoints
func AddEndpoints(accountManager account.Manager, router *mux.Router) { func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
h := newHandler(accountManager) h := newHandler(accountManager)
router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS") router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getAllIdentityProviders)).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS") router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Create, h.createIdentityProvider)).Methods("POST", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS") router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getIdentityProvider)).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS") router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Update, h.updateIdentityProvider)).Methods("PUT", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS") router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Delete, h.deleteIdentityProvider)).Methods("DELETE", "OPTIONS")
} }
func newHandler(accountManager account.Manager) *handler { func newHandler(accountManager account.Manager) *handler {
@@ -36,16 +39,8 @@ func newHandler(accountManager account.Manager) *handler {
} }
// getAllIdentityProviders returns all identity providers for the account // getAllIdentityProviders returns all identity providers for the account
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) providers, err := h.accountManager.GetIdentityProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -60,15 +55,7 @@ func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request
} }
// getIdentityProvider returns a specific identity provider // getIdentityProvider returns a specific identity provider
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) { func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
idpID := vars["idpId"] idpID := vars["idpId"]
if idpID == "" { if idpID == "" {
@@ -76,7 +63,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
return return
} }
provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID) provider, err := h.accountManager.GetIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -86,15 +73,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
} }
// createIdentityProvider creates a new identity provider // createIdentityProvider creates a new identity provider
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) { func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.IdentityProviderRequest var req api.IdentityProviderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -103,7 +82,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request)
idp := fromAPIRequest(&req) idp := fromAPIRequest(&req)
created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp) created, err := h.accountManager.CreateIdentityProvider(r.Context(), userAuth.AccountId, userAuth.UserId, idp)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -113,15 +92,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request)
} }
// updateIdentityProvider updates an existing identity provider // updateIdentityProvider updates an existing identity provider
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) { func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
idpID := vars["idpId"] idpID := vars["idpId"]
if idpID == "" { if idpID == "" {
@@ -137,7 +108,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request)
idp := fromAPIRequest(&req) idp := fromAPIRequest(&req)
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp) updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId, idp)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -147,15 +118,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request)
} }
// deleteIdentityProvider deletes an identity provider // deleteIdentityProvider deletes an identity provider
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
idpID := vars["idpId"] idpID := vars["idpId"]
if idpID == "" { if idpID == "" {
@@ -163,7 +126,7 @@ func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request)
return return
} }
if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil { if err := h.accountManager.DeleteIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -120,7 +121,7 @@ func TestGetAllIdentityProviders(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET") router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.getAllIdentityProviders)).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -180,7 +181,7 @@ func TestGetIdentityProvider(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET") router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.getIdentityProvider)).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -242,7 +243,7 @@ func TestCreateIdentityProvider(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST") router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.createIdentityProvider)).Methods("POST")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -328,7 +329,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT") router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.updateIdentityProvider)).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -388,7 +389,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE") router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.deleteIdentityProvider)).Methods("DELETE")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -7,7 +7,11 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbinstance "github.com/netbirdio/netbird/management/server/instance" nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
@@ -29,12 +33,12 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
} }
// AddVersionEndpoint registers the authenticated version endpoint. // AddVersionEndpoint registers the authenticated version endpoint.
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) { func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router, permissionsManager permissions.Manager) {
h := &handler{ h := &handler{
instanceManager: instanceManager, instanceManager: instanceManager,
} }
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS") router.HandleFunc("/instance/version", permissionsManager.WithPermission(modules.Settings, operations.Read, h.getVersionInfo)).Methods("GET", "OPTIONS")
} }
// getInstanceStatus returns the instance status including whether setup is required. // getInstanceStatus returns the instance status including whether setup is required.
@@ -77,7 +81,7 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
// getVersionInfo returns version information for NetBird components. // getVersionInfo returns version information for NetBird components.
// This endpoint requires authentication. // This endpoint requires authentication.
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) { func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context()) versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err) log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)

View File

@@ -10,12 +10,17 @@ import (
"net/mail" "net/mail"
"testing" "testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
nbinstance "github.com/netbirdio/netbird/management/server/instance" nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -295,8 +300,15 @@ func TestSetup_ManagerError(t *testing.T) {
func TestGetVersionInfo_Success(t *testing.T) { func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{} manager := &mockInstanceManager{}
ctrl := gomock.NewController(t)
permissionsManager := permissions.NewMockManager(ctrl)
permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(w, r, &auth.UserAuth{})
}
}).AnyTimes()
router := mux.NewRouter() router := mux.NewRouter()
AddVersionEndpoint(manager, router) AddVersionEndpoint(manager, router, permissionsManager)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@@ -323,8 +335,15 @@ func TestGetVersionInfo_Error(t *testing.T) {
return nil, errors.New("failed to fetch versions") return nil, errors.New("failed to fetch versions")
}, },
} }
ctrl := gomock.NewController(t)
permissionsManager := permissions.NewMockManager(ctrl)
permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(w, r, &auth.UserAuth{})
}
}).AnyTimes()
router := mux.NewRouter() router := mux.NewRouter()
AddVersionEndpoint(manager, router) AddVersionEndpoint(manager, router, permissionsManager)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()

View File

@@ -9,8 +9,10 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources"
@@ -18,6 +20,7 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/networks/types"
nbtypes "github.com/netbirdio/netbird/management/server/types" nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -33,16 +36,16 @@ type handler struct {
groupsManager groups.Manager groupsManager groups.Manager
} }
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, router *mux.Router) { func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, permissionsManager permissions.Manager, router *mux.Router) {
addRouterEndpoints(routerManager, router) addRouterEndpoints(routerManager, permissionsManager, router)
addResourceEndpoints(resourceManager, groupsManager, router) addResourceEndpoints(resourceManager, groupsManager, permissionsManager, router)
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager) networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager)
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getAllNetworks)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Create, networksHandler.createNetwork)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getNetwork)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS") router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Update, networksHandler.updateNetwork)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, networksHandler.deleteNetwork)).Methods("DELETE", "OPTIONS")
} }
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler { func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler {
@@ -55,40 +58,32 @@ func newHandler(networksManager networks.Manager, resourceManager resources.Mana
} }
} }
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) networks, err := h.networksManager.GetAllNetworks(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountID, userID := userAuth.AccountId, userAuth.UserId resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID) groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID) routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID) account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), accountID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -97,16 +92,9 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account)) util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account))
} }
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkRequest var req api.NetworkRequest
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -115,14 +103,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
network := &types.Network{} network := &types.Network{}
network.FromAPIRequest(&req) network.FromAPIRequest(&req)
network.AccountID = accountID network.AccountID = userAuth.AccountId
network, err = h.networksManager.CreateNetwork(r.Context(), userID, network) network, err = h.networksManager.CreateNetwork(r.Context(), userAuth.UserId, network)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
account, err := h.accountManager.GetAccount(r.Context(), accountID) account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -133,14 +121,7 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs)) util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs))
} }
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
networkID := vars["networkId"] networkID := vars["networkId"]
if len(networkID) == 0 { if len(networkID) == 0 {
@@ -148,19 +129,19 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
return return
} }
network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID) network, err := h.networksManager.GetNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
account, err := h.accountManager.GetAccount(r.Context(), accountID) account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -171,14 +152,7 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
} }
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
networkID := vars["networkId"] networkID := vars["networkId"]
if len(networkID) == 0 { if len(networkID) == 0 {
@@ -187,7 +161,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
} }
var req api.NetworkRequest var req api.NetworkRequest
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -197,20 +171,20 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
network.FromAPIRequest(&req) network.FromAPIRequest(&req)
network.ID = networkID network.ID = networkID
network.AccountID = accountID network.AccountID = userAuth.AccountId
network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network) network, err = h.networksManager.UpdateNetwork(r.Context(), userAuth.UserId, network)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
account, err := h.accountManager.GetAccount(r.Context(), accountID) account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -221,14 +195,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
} }
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
networkID := vars["networkId"] networkID := vars["networkId"]
if len(networkID) == 0 { if len(networkID) == 0 {
@@ -236,7 +203,7 @@ func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID) err := h.networksManager.DeleteNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -6,10 +6,13 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
@@ -19,14 +22,14 @@ type resourceHandler struct {
groupsManager groups.Manager groupsManager groups.Manager
} }
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) { func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, permissionsManager permissions.Manager, router *mux.Router) {
resourceHandler := newResourceHandler(resourcesManager, groupsManager) resourceHandler := newResourceHandler(resourcesManager, groupsManager)
router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS") router.HandleFunc("/networks/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInAccount)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInNetwork)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS") router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Create, resourceHandler.createResource)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getResource)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS") router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Update, resourceHandler.updateResource)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS") router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, resourceHandler.deleteResource)).Methods("DELETE", "OPTIONS")
} }
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler { func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler {
@@ -36,22 +39,15 @@ func newResourceHandler(resourceManager resources.Manager, groupsManager groups.
} }
} }
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) { func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"] networkID := mux.Vars(r)["networkId"]
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID) resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -66,22 +62,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, resourcesResponse) util.WriteJSONObject(r.Context(), w, resourcesResponse)
} }
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) { func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountID, userID := userAuth.AccountId, userAuth.UserId grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -97,17 +85,9 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, resourcesResponse) util.WriteJSONObject(r.Context(), w, resourcesResponse)
} }
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) { func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkResourceRequest var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -117,14 +97,14 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
resource.FromAPIRequest(&req) resource.FromAPIRequest(&req)
resource.NetworkID = mux.Vars(r)["networkId"] resource.NetworkID = mux.Vars(r)["networkId"]
resource.AccountID = accountID resource.AccountID = userAuth.AccountId
resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource) resource, err = h.resourceManager.CreateResource(r.Context(), userAuth.UserId, resource)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -135,23 +115,16 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
} }
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"] networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"] resourceID := mux.Vars(r)["resourceId"]
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID) resource, err := h.resourceManager.GetResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -162,16 +135,9 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
} }
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) { func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkResourceRequest var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -182,14 +148,14 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
resource.ID = mux.Vars(r)["resourceId"] resource.ID = mux.Vars(r)["resourceId"]
resource.NetworkID = mux.Vars(r)["networkId"] resource.NetworkID = mux.Vars(r)["networkId"]
resource.AccountID = accountID resource.AccountID = userAuth.AccountId
resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource) resource, err = h.resourceManager.UpdateResource(r.Context(), userAuth.UserId, resource)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -200,17 +166,10 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
} }
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) { func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"] networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"] resourceID := mux.Vars(r)["resourceId"]
err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID) err := h.resourceManager.DeleteResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -6,9 +6,12 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
@@ -17,14 +20,14 @@ type routersHandler struct {
routersManager routers.Manager routersManager routers.Manager
} }
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) { func addRouterEndpoints(routersManager routers.Manager, permissionsManager permissions.Manager, router *mux.Router) {
routersHandler := newRoutersHandler(routersManager) routersHandler := newRoutersHandler(routersManager)
router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS") router.HandleFunc("/networks/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getAllRouters)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getNetworkRouters)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS") router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Create, routersHandler.createRouter)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getRouter)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS") router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Update, routersHandler.updateRouter)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS") router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, routersHandler.deleteRouter)).Methods("DELETE", "OPTIONS")
} }
func newRoutersHandler(routersManager routers.Manager) *routersHandler { func newRoutersHandler(routersManager routers.Manager) *routersHandler {
@@ -33,16 +36,8 @@ func newRoutersHandler(routersManager routers.Manager) *routersHandler {
} }
} }
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -58,17 +53,9 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, routersResponse) util.WriteJSONObject(r.Context(), w, routersResponse)
} }
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) { func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"] networkID := mux.Vars(r)["networkId"]
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID) routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -82,18 +69,10 @@ func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Reques
util.WriteJSONObject(r.Context(), w, routersResponse) util.WriteJSONObject(r.Context(), w, routersResponse)
} }
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"] networkID := mux.Vars(r)["networkId"]
var req api.NetworkRouterRequest var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -103,7 +82,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
router.FromAPIRequest(&req) router.FromAPIRequest(&req)
router.NetworkID = networkID router.NetworkID = networkID
router.AccountID = accountID router.AccountID = userAuth.AccountId
router.Enabled = true router.Enabled = true
if err := router.Validate(); err != nil { if err := router.Validate(); err != nil {
@@ -111,7 +90,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
return return
} }
router, err = h.routersManager.CreateRouter(r.Context(), userID, router) router, err = h.routersManager.CreateRouter(r.Context(), userAuth.UserId, router)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -120,18 +99,10 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
} }
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routerID := mux.Vars(r)["routerId"] routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"] networkID := mux.Vars(r)["networkId"]
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID) router, err := h.routersManager.GetRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -140,17 +111,9 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
} }
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkRouterRequest var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -161,14 +124,14 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
router.NetworkID = mux.Vars(r)["networkId"] router.NetworkID = mux.Vars(r)["networkId"]
router.ID = mux.Vars(r)["routerId"] router.ID = mux.Vars(r)["routerId"]
router.AccountID = accountID router.AccountID = userAuth.AccountId
if err := router.Validate(); err != nil { if err := router.Validate(); err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return return
} }
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) router, err = h.routersManager.UpdateRouter(r.Context(), userAuth.UserId, router)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -177,17 +140,10 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
} }
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) { func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routerID := mux.Vars(r)["routerId"] routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"] networkID := mux.Vars(r)["networkId"]
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID) err := h.routersManager.DeleteRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -1,7 +1,6 @@
package peers package peers
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -12,15 +11,15 @@ import (
"github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -35,14 +34,15 @@ type Handler struct {
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) { func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager) peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAllPeers)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetPeer)).Methods("GET", "OPTIONS")
Methods("GET", "PUT", "DELETE", "OPTIONS") router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Update, peersHandler.UpdatePeer)).Methods("PUT", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Delete, peersHandler.DeletePeer)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS") router.HandleFunc("/peers/{peerId}/accessible-peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAccessiblePeers)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}/temporary-access", permissionsManager.WithPermission(modules.Peers, operations.Create, peersHandler.CreateTemporaryAccess)).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS") router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.ListJobs)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Create, peersHandler.CreateJob)).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.GetJob)).Methods("GET", "OPTIONS")
} }
// NewHandler creates a new peers Handler // NewHandler creates a new peers Handler
@@ -54,14 +54,7 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
} }
} }
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r) vars := mux.Vars(r)
peerID := vars["peerId"] peerID := vars["peerId"]
@@ -73,37 +66,30 @@ func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req) job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil { if err := h.accountManager.CreatePeerJob(r.Context(), userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp, err := toSingleJobResponse(job) resp, err := toSingleJobResponse(job)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(ctx, w, resp) util.WriteJSONObject(r.Context(), w, resp)
} }
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) { func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r) vars := mux.Vars(r)
peerID := vars["peerId"] peerID := vars["peerId"]
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID) jobs, err := h.accountManager.GetAllPeerJobs(r.Context(), userAuth.AccountId, userAuth.UserId, peerID)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -111,79 +97,88 @@ func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
for _, job := range jobs { for _, job := range jobs {
resp, err := toSingleJobResponse(job) resp, err := toSingleJobResponse(job)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
respBody = append(respBody, resp) respBody = append(respBody, resp)
} }
util.WriteJSONObject(ctx, w, respBody) util.WriteJSONObject(r.Context(), w, respBody)
} }
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r) vars := mux.Vars(r)
peerID := vars["peerId"] peerID := vars["peerId"]
jobID := vars["jobId"] jobID := vars["jobId"]
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID) job, err := h.accountManager.GetPeerJobByID(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, jobID)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp, err := toSingleJobResponse(job) resp, err := toSingleJobResponse(job)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(ctx, w, resp) util.WriteJSONObject(r.Context(), w, resp)
} }
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { // GetPeer handles GET request for a single peer
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) func (h *Handler) GetPeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
peer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
if peer.ProxyMeta.Embedded { if peer.ProxyMeta.Embedded {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
return return
} }
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
dnsDomain := h.networkMapController.GetDNSDomain(settings) dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grps, _ := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return return
} }
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
reason := invalidPeers[peer.ID] reason := invalidPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
} }
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { // UpdatePeer handles PUT request to update a peer
func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
req := &api.PeerRequest{} req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
@@ -192,11 +187,10 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
} }
update := &nbpeer.Peer{ update := &nbpeer.Peer{
ID: peerID, ID: peerID,
SSHEnabled: req.SshEnabled, SSHEnabled: req.SshEnabled,
Name: req.Name, Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled, LoginExpirationEnabled: req.LoginExpirationEnabled,
InactivityExpirationEnabled: req.InactivityExpirationEnabled, InactivityExpirationEnabled: req.InactivityExpirationEnabled,
} }
@@ -210,41 +204,41 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
if req.Ip != nil { if req.Ip != nil {
addr, err := netip.ParseAddr(*req.Ip) addr, err := netip.ParseAddr(*req.Ip)
if err != nil { if err != nil {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
return return
} }
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil { if err = h.accountManager.UpdatePeerIP(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, addr); err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
} }
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) peer, err := h.accountManager.UpdatePeer(r.Context(), userAuth.AccountId, userAuth.UserId, update)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
dnsDomain := h.networkMapController.GetDNSDomain(settings) dnsDomain := h.networkMapController.GetDNSDomain(settings)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peer.ID)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(r.Context(), err, w)
return return
} }
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return return
} }
@@ -254,25 +248,8 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
} }
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { // DeletePeer handles DELETE request to delete a peer
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(ctx, w, util.EmptyObject{})
}
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
peerID := vars["peerId"] peerID := vars["peerId"]
if len(peerID) == 0 { if len(peerID) == 0 {
@@ -280,48 +257,34 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
return return
} }
switch r.Method { err := h.accountManager.DeletePeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
case http.MethodDelete: if err != nil {
h.deletePeer(r.Context(), accountID, userID, peerID, w) log.WithContext(r.Context()).Errorf("failed to delete peer: %v", err)
util.WriteError(r.Context(), err, w)
return return
case http.MethodGet:
h.getPeer(r.Context(), accountID, peerID, userID, w)
return
case http.MethodPut:
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
return
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
} }
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
// GetAllPeers returns a list of all peers associated with a provided account // GetAllPeers returns a list of all peers associated with a provided account
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
nameFilter := r.URL.Query().Get("name") nameFilter := r.URL.Query().Get("name")
ipFilter := r.URL.Query().Get("ip") ipFilter := r.URL.Query().Get("ip")
accountID, userID := userAuth.AccountId, userAuth.UserId peers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, nameFilter, ipFilter, true)
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator) settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
dnsDomain := h.networkMapController.GetDNSDomain(settings) dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) grps, _ := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers)) grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
respBody := make([]*api.PeerBatch, 0, len(peers)) respBody := make([]*api.PeerBatch, 0, len(peers))
@@ -332,7 +295,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0)) respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
} }
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -356,15 +319,7 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM
} }
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
peerID := vars["peerId"] peerID := vars["peerId"]
if len(peerID) == 0 { if len(peerID) == 0 {
@@ -372,25 +327,22 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return return
} }
user, err := h.accountManager.GetUserByID(r.Context(), userID) user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read) account, err := h.accountManager.GetAccountByID(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
if !allowed && !userAuth.IsChild { // Check if user is an admin/service user through their role
isAdmin := user.Role == types.UserRoleAdmin || user.Role == types.UserRoleOwner
if !isAdmin && !user.IsServiceUser && !userAuth.IsChild {
if account.Settings.RegularUsersViewBlocked { if account.Settings.RegularUsersViewBlocked {
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{}) util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
return return
@@ -408,7 +360,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
} }
} }
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -422,13 +374,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
} }
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r) vars := mux.Vars(r)
peerID := vars["peerId"] peerID := vars["peerId"]
if len(peerID) == 0 { if len(peerID) == 0 {
@@ -437,7 +383,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
} }
var req api.PeerTemporaryAccessRequest var req api.PeerTemporaryAccessRequest
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return

View File

@@ -19,11 +19,11 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
@@ -174,7 +174,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
return nil, fmt.Errorf("user not found") return nil, fmt.Errorf("user not found")
} }
}, },
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) {
return peers, nil return peers, nil
}, },
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
@@ -307,9 +307,9 @@ func TestGetPeers(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET") router.HandleFunc("/api/peers/", permissions.WrapHandler(p.GetAllPeers)).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET") router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.GetPeer)).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT") router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -498,7 +498,7 @@ func TestGetAccessiblePeers(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET") router.HandleFunc("/api/peers/{peerId}/accessible-peers", permissions.WrapHandler(p.GetAccessiblePeers)).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -582,7 +582,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT") router.HandleFunc("/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT")
router.ServeHTTP(rr, req) router.ServeHTTP(rr, req)

View File

@@ -14,12 +14,12 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
@@ -121,7 +121,7 @@ func TestGetCitiesByCountry(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET") router.HandleFunc("/api/locations/countries/{country}/cities", permissions.WrapHandler(geolocationHandler.getCitiesByCountry)).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -214,7 +214,7 @@ func TestGetAllCountries(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET") router.HandleFunc("/api/locations/countries", permissions.WrapHandler(geolocationHandler.getAllCountries)).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -6,12 +6,12 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -30,8 +30,8 @@ type geolocationsHandler struct {
func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) { func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) {
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager) locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager)
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") router.HandleFunc("/locations/countries", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getAllCountries)).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") router.HandleFunc("/locations/countries/{country}/cities", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getCitiesByCountry)).Methods("GET", "OPTIONS")
} }
// newGeolocationsHandlerHandler creates a new Geolocations handler // newGeolocationsHandlerHandler creates a new Geolocations handler
@@ -44,12 +44,7 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
} }
// getAllCountries retrieves a list of all countries // getAllCountries retrieves a list of all countries
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
if l.geolocationManager == nil { if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready // TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
@@ -70,12 +65,7 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
} }
// getCitiesByCountry retrieves a list of cities based on the given country code // getCitiesByCountry retrieves a list of cities based on the given country code
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r) vars := mux.Vars(r)
countryCode := vars["country"] countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) { if !countryCodeRegex.MatchString(countryCode) {
@@ -102,27 +92,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
util.WriteJSONObject(r.Context(), w, cities) util.WriteJSONObject(r.Context(), w, cities)
} }
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
return err
}
accountID, userID := userAuth.AccountId, userAuth.UserId
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
return nil
}
func toCountryResponse(country geolocation.Country) api.Country { func toCountryResponse(country geolocation.Country) api.Country {
return api.Country{ return api.Country{
CountryName: country.CountryName, CountryName: country.CountryName,

View File

@@ -7,10 +7,13 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +24,13 @@ type handler struct {
accountManager account.Manager accountManager account.Manager
} }
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) { func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
policiesHandler := newHandler(accountManager) policiesHandler := newHandler(accountManager)
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getAllPolicies)).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Create, policiesHandler.createPolicy)).Methods("POST", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS") router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Update, policiesHandler.updatePolicy)).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS") router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getPolicy)).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS") router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, policiesHandler.deletePolicy)).Methods("DELETE", "OPTIONS")
} }
// newHandler creates a new policies handler // newHandler creates a new policies handler
@@ -38,22 +41,14 @@ func newHandler(accountManager account.Manager) *handler {
} }
// getAllPolicies list for the account // getAllPolicies list for the account
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) listPolicies, err := h.accountManager.ListPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountID, userID := userAuth.AccountId, userAuth.UserId allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -73,15 +68,7 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
} }
// updatePolicy handles update to a policy identified by a given ID // updatePolicy handles update to a policy identified by a given ID
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
policyID := vars["policyId"] policyID := vars["policyId"]
if len(policyID) == 0 { if len(policyID) == 0 {
@@ -89,26 +76,18 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
return return
} }
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) _, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
h.savePolicy(w, r, accountID, userID, policyID, false) h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, policyID, false)
} }
// createPolicy handles policy creation request // createPolicy handles policy creation request
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, "", true)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
h.savePolicy(w, r, accountID, userID, "", true)
} }
// savePolicy handles policy creation and update // savePolicy handles policy creation and update
@@ -303,14 +282,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
} }
// deletePolicy handles policy deletion request // deletePolicy handles policy deletion request
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
policyID := vars["policyId"] policyID := vars["policyId"]
if len(policyID) == 0 { if len(policyID) == 0 {
@@ -318,7 +290,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
return return
} }
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil { if err := h.accountManager.DeletePolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -327,15 +299,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
} }
// getPolicy handles a group Get request identified by ID // getPolicy handles a group Get request identified by ID
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
policyID := vars["policyId"] policyID := vars["policyId"]
if len(policyID) == 0 { if len(policyID) == 0 {
@@ -343,13 +307,13 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
return return
} }
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) policy, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -13,6 +13,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -111,7 +112,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET") router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.getPolicy)).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -275,8 +276,8 @@ func TestPoliciesWritePolicy(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST") router.HandleFunc("/api/policies", permissions.WrapHandler(p.createPolicy)).Methods("POST")
router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT") router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.updatePolicy)).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -6,10 +6,13 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +24,13 @@ type postureChecksHandler struct {
geolocationManager geolocation.Geolocation geolocationManager geolocation.Geolocation
} }
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) { func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager) postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getAllPostureChecks)).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Create, postureCheckHandler.createPostureCheck)).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS") router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Update, postureCheckHandler.updatePostureCheck)).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS") router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getPostureCheck)).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS") router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, postureCheckHandler.deletePostureCheck)).Methods("DELETE", "OPTIONS")
} }
// newPostureChecksHandler creates a new PostureChecks handler // newPostureChecksHandler creates a new PostureChecks handler
@@ -39,15 +42,8 @@ func newPostureChecksHandler(accountManager account.Manager, geolocationManager
} }
// getAllPostureChecks list for the account // getAllPostureChecks list for the account
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) { func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -62,15 +58,7 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt
} }
// updatePostureCheck handles update to a posture check identified by a given ID // updatePostureCheck handles update to a posture check identified by a given ID
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"] postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 { if len(postureChecksID) == 0 {
@@ -78,37 +66,22 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http
return return
} }
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) _, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
p.savePostureChecks(w, r, accountID, userID, postureChecksID, false) p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, postureChecksID, false)
} }
// createPostureCheck handles posture check creation request // createPostureCheck handles posture check creation request
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) { func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, "", true)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
p.savePostureChecks(w, r, accountID, userID, "", true)
} }
// getPostureCheck handles a posture check Get request identified by ID // getPostureCheck handles a posture check Get request identified by ID
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) { func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"] postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 { if len(postureChecksID) == 0 {
@@ -116,7 +89,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
return return
} }
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -126,14 +99,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
} }
// deletePostureCheck handles posture check deletion request // deletePostureCheck handles posture check deletion request
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"] postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 { if len(postureChecksID) == 0 {
@@ -141,7 +107,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http
return return
} }
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil { if err := p.accountManager.DeletePostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
@@ -183,7 +184,7 @@ func TestGetPostureCheck(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET") router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(p.getPostureCheck)).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -841,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) {
} }
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST") router.HandleFunc("/api/posture-checks", permissions.WrapHandler(defaultHandler.createPostureCheck)).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT") router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(defaultHandler.updatePostureCheck)).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -8,9 +8,12 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
@@ -26,13 +29,13 @@ type handler struct {
accountManager account.Manager accountManager account.Manager
} }
func AddEndpoints(accountManager account.Manager, router *mux.Router) { func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
routesHandler := newHandler(accountManager) routesHandler := newHandler(accountManager)
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS") router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getAllRoutes)).Methods("GET", "OPTIONS")
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS") router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Create, routesHandler.createRoute)).Methods("POST", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS") router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Update, routesHandler.updateRoute)).Methods("PUT", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS") router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getRoute)).Methods("GET", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS") router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Delete, routesHandler.deleteRoute)).Methods("DELETE", "OPTIONS")
} }
// newHandler returns a new instance of routes handler // newHandler returns a new instance of routes handler
@@ -43,16 +46,8 @@ func newHandler(accountManager account.Manager) *handler {
} }
// getAllRoutes returns the list of routes for the account // getAllRoutes returns the list of routes for the account
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) routes, err := h.accountManager.ListRoutes(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -71,17 +66,9 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
} }
// createRoute handles route creation request // createRoute handles route creation request
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PostApiRoutesJSONRequestBody var req api.PostApiRoutesJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -134,8 +121,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
skipAutoApply = false skipAutoApply = false
} }
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, newRoute, err := h.accountManager.CreateRoute(r.Context(), userAuth.AccountId, newPrefix, networkType, domains, peerId, peerGroupIds,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply) req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userAuth.UserId, req.KeepRoute, skipAutoApply)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@@ -185,14 +172,7 @@ func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *
} }
// updateRoute handles update to a route identified by a given ID // updateRoute handles update to a route identified by a given ID
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
routeID := vars["routeId"] routeID := vars["routeId"]
if len(routeID) == 0 { if len(routeID) == 0 {
@@ -200,7 +180,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) _, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -271,7 +251,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.AccessControlGroups = *req.AccessControlGroups newRoute.AccessControlGroups = *req.AccessControlGroups
} }
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute) err = h.accountManager.SaveRoute(r.Context(), userAuth.AccountId, userAuth.UserId, newRoute)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -287,21 +267,14 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
} }
// deleteRoute handles route deletion request // deleteRoute handles route deletion request
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routeID := mux.Vars(r)["routeId"] routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 { if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID) err := h.accountManager.DeleteRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -311,22 +284,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
} }
// getRoute handles a route Get request identified by ID // getRoute handles a route Get request identified by ID
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { func (h *handler) getRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routeID := mux.Vars(r)["routeId"] routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 { if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) foundRoute, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
@@ -501,10 +502,10 @@ func TestRoutesHandlers(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET") router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.getRoute)).Methods("GET")
router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE") router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.deleteRoute)).Methods("DELETE")
router.HandleFunc("/api/routes", p.createRoute).Methods("POST") router.HandleFunc("/api/routes", permissions.WrapHandler(p.createRoute)).Methods("POST")
router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT") router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.updateRoute)).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -8,9 +8,12 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +24,13 @@ type handler struct {
accountManager account.Manager accountManager account.Manager
} }
func AddEndpoints(accountManager account.Manager, router *mux.Router) { func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
keysHandler := newHandler(accountManager) keysHandler := newHandler(accountManager)
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS") router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getAllSetupKeys)).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS") router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Create, keysHandler.createSetupKey)).Methods("POST", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS") router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getSetupKey)).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS") router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Update, keysHandler.updateSetupKey)).Methods("PUT", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS") router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Delete, keysHandler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
} }
// newHandler creates a new setup key handler // newHandler creates a new setup key handler
@@ -38,16 +41,9 @@ func newHandler(accountManager account.Manager) *handler {
} }
// createSetupKey is a POST requests that creates a new SetupKey // createSetupKey is a POST requests that creates a new SetupKey
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiSetupKeysJSONRequestBody{} req := &api.PostApiSetupKeysJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -85,8 +81,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
allowExtraDNSLabels = *req.AllowExtraDnsLabels allowExtraDNSLabels = *req.AllowExtraDnsLabels
} }
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn, setupKey, err := h.accountManager.CreateSetupKey(r.Context(), userAuth.AccountId, req.Name, types.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, userID, ephemeral, allowExtraDNSLabels) req.AutoGroups, req.UsageLimit, userAuth.UserId, ephemeral, allowExtraDNSLabels)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -100,14 +96,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
} }
// getSetupKey is a GET request to get a SetupKey by ID // getSetupKey is a GET request to get a SetupKey by ID
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["keyId"] keyID := vars["keyId"]
if len(keyID) == 0 { if len(keyID) == 0 {
@@ -115,7 +104,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
return return
} }
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID) key, err := h.accountManager.GetSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -125,14 +114,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
} }
// updateSetupKey is a PUT request to update server.SetupKey // updateSetupKey is a PUT request to update server.SetupKey
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["keyId"] keyID := vars["keyId"]
if len(keyID) == 0 { if len(keyID) == 0 {
@@ -141,7 +123,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
} }
req := &api.PutApiSetupKeysKeyIdJSONRequestBody{} req := &api.PutApiSetupKeysKeyIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -157,7 +139,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
newKey.Revoked = req.Revoked newKey.Revoked = req.Revoked
newKey.Id = keyID newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) newKey, err = h.accountManager.SaveSetupKey(r.Context(), userAuth.AccountId, newKey, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -166,15 +148,8 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
} }
// getAllSetupKeys is a GET request that returns a list of SetupKey // getAllSetupKeys is a GET request that returns a list of SetupKey
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -188,14 +163,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiSetupKeys) util.WriteJSONObject(r.Context(), w, apiSetupKeys)
} }
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["keyId"] keyID := vars["keyId"]
if len(keyID) == 0 { if len(keyID) == 0 {
@@ -203,7 +171,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID) err := h.accountManager.DeleteSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -171,11 +172,11 @@ func TestSetupKeysHandlers(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS") router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.getAllSetupKeys)).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS") router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.createSetupKey)).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.getSetupKey)).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.updateSetupKey)).Methods("PUT", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -9,10 +9,13 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -55,14 +58,14 @@ type invitesHandler struct {
} }
// AddInvitesEndpoints registers invite-related endpoints // AddInvitesEndpoints registers invite-related endpoints
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) { func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
h := &invitesHandler{accountManager: accountManager} h := &invitesHandler{accountManager: accountManager}
// Authenticated endpoints (require admin) // Authenticated endpoints (require admin)
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS") router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS") router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS") router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS") router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS")
} }
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting // AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
@@ -79,14 +82,7 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route
} }
// listInvites handles GET /api/users/invites // listInvites handles GET /api/users/invites
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) { func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId) invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@@ -102,14 +98,7 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
} }
// createInvite handles POST /api/users/invites // createInvite handles POST /api/users/invites
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) { func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.UserInviteCreateRequest var req api.UserInviteCreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -191,18 +180,12 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
} }
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate // regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) { func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
} }
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r) vars := mux.Vars(r)
inviteID := vars["inviteId"] inviteID := vars["inviteId"]
if inviteID == "" { if inviteID == "" {
@@ -238,14 +221,7 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request
} }
// deleteInvite handles DELETE /api/users/invites/{inviteId} // deleteInvite handles DELETE /api/users/invites/{inviteId}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) { func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r) vars := mux.Vars(r)
inviteID := vars["inviteId"] inviteID := vars["inviteId"]
if inviteID == "" { if inviteID == "" {
@@ -253,7 +229,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID) err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -110,7 +110,11 @@ func TestListInvites(t *testing.T) {
}) })
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
handler.listInvites(rr, req) userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.listInvites(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code) assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -235,7 +239,11 @@ func TestCreateInvite(t *testing.T) {
}) })
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
handler.createInvite(rr, req) userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.createInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code) assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -573,7 +581,11 @@ func TestRegenerateInvite(t *testing.T) {
} }
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
handler.regenerateInvite(rr, req) userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.regenerateInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code) assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -651,7 +663,11 @@ func TestDeleteInvite(t *testing.T) {
} }
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
handler.deleteInvite(rr, req) userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.deleteInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code) assert.Equal(t, tc.expectedStatus, rr.Code)
}) })

View File

@@ -6,9 +6,12 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -19,12 +22,12 @@ type patHandler struct {
accountManager account.Manager accountManager account.Manager
} }
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) { func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
tokenHandler := newPATsHandler(accountManager) tokenHandler := newPATsHandler(accountManager)
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS") router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS") router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS") router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS")
} }
// newPATsHandler creates a new patHandler HTTP handler // newPATsHandler creates a new patHandler HTTP handler
@@ -35,22 +38,15 @@ func newPATsHandler(accountManager account.Manager) *patHandler {
} }
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user // getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(userID) == 0 { if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID) pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -65,14 +61,7 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
} }
// getToken is HTTP GET handler that returns a personal access token for the given user // getToken is HTTP GET handler that returns a personal access token for the given user
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
@@ -86,7 +75,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
return return
} }
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID) pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -96,14 +85,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
} }
// createToken is HTTP POST handler that creates a personal access token for the given user // createToken is HTTP POST handler that creates a personal access token for the given user
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
@@ -112,13 +94,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
} }
var req api.PostApiUsersUserIdTokensJSONRequestBody var req api.PostApiUsersUserIdTokensJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn) pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -128,14 +110,7 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
} }
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user // deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
@@ -149,7 +124,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID) err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -181,10 +182,10 @@ func TestTokenHandlers(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET") router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.getAllTokens)).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET") router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.getToken)).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST") router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.createToken)).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE") router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.deleteToken)).Methods("DELETE")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -8,14 +8,16 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
nbcontext "github.com/netbirdio/netbird/management/server/context"
) )
// handler is a handler that returns users of the account // handler is a handler that returns users of the account
@@ -23,18 +25,18 @@ type handler struct {
accountManager account.Manager accountManager account.Manager
} }
func AddEndpoints(accountManager account.Manager, router *mux.Router) { func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
userHandler := newHandler(accountManager) userHandler := newHandler(accountManager)
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers, userHandler.getOwnUser)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS") router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser, userHandler.getCurrentUserFallback)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS") router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS") router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS")
addUsersTokensEndpoint(accountManager, router) addUsersTokensEndpoint(accountManager, router, permissionsManager)
} }
// newHandler creates a new UsersHandler HTTP handler // newHandler creates a new UsersHandler HTTP handler
@@ -45,19 +47,12 @@ func newHandler(accountManager account.Manager) *handler {
} }
// updateUser is a PUT requests to update User data // updateUser is a PUT requests to update User data
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPut { if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
} }
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
@@ -71,6 +66,11 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
return return
} }
if existingUser.AccountID != userAuth.AccountId {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "user not found"), w)
return
}
req := &api.PutApiUsersUserIdJSONRequestBody{} req := &api.PutApiUsersUserIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
@@ -89,7 +89,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
return return
} }
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{ newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{
Id: targetUserID, Id: targetUserID,
Role: userRole, Role: userRole,
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
@@ -102,23 +102,16 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID)) util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
} }
// deleteUser is a DELETE request to delete a user // deleteUser is a DELETE request to delete a user
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodDelete { if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
} }
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
@@ -126,7 +119,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID) err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -136,21 +129,14 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
} }
// createUser creates a User in the system with a status "invited" (effectively this is a user invite). // createUser creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
} }
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiUsersJSONRequestBody{} req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
@@ -171,7 +157,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name name = *req.Name
} }
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{ newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{
Email: email, Email: email,
Name: name, Name: name,
Role: req.Role, Role: req.Role,
@@ -183,25 +169,18 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID)) util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
} }
// getAllUsers returns a list of users of the account this user belongs to. // getAllUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager. // It also gathers additional user data (like email and name) from the IDP manager.
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
} }
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -215,7 +194,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
continue continue
} }
if serviceUser == "" { if serviceUser == "" {
users = append(users, toUserResponse(d, userID)) users = append(users, toUserResponse(d, userAuth.UserId))
continue continue
} }
@@ -226,7 +205,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
return return
} }
if includeServiceUser == d.IsServiceUser { if includeServiceUser == d.IsServiceUser {
users = append(users, toUserResponse(d, userID)) users = append(users, toUserResponse(d, userAuth.UserId))
} }
} }
@@ -235,19 +214,12 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
// inviteUser resend invitations to users who haven't activated their accounts, // inviteUser resend invitations to users who haven't activated their accounts,
// prior to the expiration period. // prior to the expiration period.
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
} }
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
@@ -255,7 +227,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID) err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -264,19 +236,13 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) { func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
} }
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth) user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -356,7 +322,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
} }
// approveUser is a POST request to approve a user that is pending approval // approveUser is a POST request to approve a user that is pending approval
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
@@ -369,11 +335,6 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
return return
} }
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@@ -385,7 +346,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
} }
// rejectUser is a DELETE request to reject a user that is pending approval // rejectUser is a DELETE request to reject a user that is pending approval
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) { func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodDelete { if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
@@ -398,12 +359,7 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
return return
} }
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -421,7 +377,7 @@ type passwordChangeRequest struct {
// changePassword is a PUT request to change user's password. // changePassword is a PUT request to change user's password.
// Only available when embedded IDP is enabled. // Only available when embedded IDP is enabled.
// Users can only change their own password. // Users can only change their own password.
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) { func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPut { if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
@@ -434,19 +390,13 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
return return
} }
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req passwordChangeRequest var req passwordChangeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword) err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -454,3 +404,39 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
func (h *handler) getCurrentUserFallback(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.PermissionDenied {
return false
}
user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
if userErr != nil {
util.WriteError(r.Context(), userErr, w)
return true
}
util.WriteJSONObject(r.Context(), w, toUserWithPermissionsResponse(user, userAuth.UserId))
return true
}
func (h *handler) getOwnUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.PermissionDenied {
return false
}
if r.URL.Query().Get("service_user") != "" {
return false
}
user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
if userErr != nil {
util.WriteError(r.Context(), userErr, w)
return true
}
util.WriteJSONObject(r.Context(), w, []*api.User{toUserResponse(user.UserInfo, userAuth.UserId)})
return true
}

View File

@@ -15,10 +15,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
roles2 "github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
@@ -38,6 +39,7 @@ var usersTestAccount = &types.Account{
Users: map[string]*types.User{ Users: map[string]*types.User{
existingUserID: { existingUserID: {
Id: existingUserID, Id: existingUserID,
AccountID: existingAccountID,
Role: "admin", Role: "admin",
IsServiceUser: false, IsServiceUser: false,
AutoGroups: []string{"group_1"}, AutoGroups: []string{"group_1"},
@@ -45,6 +47,7 @@ var usersTestAccount = &types.Account{
}, },
regularUserID: { regularUserID: {
Id: regularUserID, Id: regularUserID,
AccountID: existingAccountID,
Role: "user", Role: "user",
IsServiceUser: false, IsServiceUser: false,
AutoGroups: []string{"group_1"}, AutoGroups: []string{"group_1"},
@@ -52,6 +55,7 @@ var usersTestAccount = &types.Account{
}, },
serviceUserID: { serviceUserID: {
Id: serviceUserID, Id: serviceUserID,
AccountID: existingAccountID,
Role: "user", Role: "user",
IsServiceUser: true, IsServiceUser: true,
AutoGroups: []string{"group_1"}, AutoGroups: []string{"group_1"},
@@ -59,6 +63,7 @@ var usersTestAccount = &types.Account{
}, },
nonDeletableServiceUserID: { nonDeletableServiceUserID: {
Id: nonDeletableServiceUserID, Id: nonDeletableServiceUserID,
AccountID: existingAccountID,
Role: "admin", Role: "admin",
IsServiceUser: true, IsServiceUser: true,
NonDeletable: true, NonDeletable: true,
@@ -151,7 +156,7 @@ func initUsersTestData() *handler {
NonDeletable: false, NonDeletable: false,
Issued: "api", Issued: "api",
}, },
Permissions: mergeRolePermissions(roles.Owner), Permissions: mergeRolePermissions(roles2.Owner),
}, nil }, nil
case "regular-user": case "regular-user":
return &users.UserInfoWithPermissions{ return &users.UserInfoWithPermissions{
@@ -165,7 +170,7 @@ func initUsersTestData() *handler {
NonDeletable: false, NonDeletable: false,
Issued: "api", Issued: "api",
}, },
Permissions: mergeRolePermissions(roles.User), Permissions: mergeRolePermissions(roles2.User),
}, nil }, nil
case "admin-user": case "admin-user":
@@ -181,7 +186,7 @@ func initUsersTestData() *handler {
LastLogin: time.Time{}, LastLogin: time.Time{},
Issued: "api", Issued: "api",
}, },
Permissions: mergeRolePermissions(roles.Admin), Permissions: mergeRolePermissions(roles2.Admin),
}, nil }, nil
case "restricted-user": case "restricted-user":
return &users.UserInfoWithPermissions{ return &users.UserInfoWithPermissions{
@@ -196,7 +201,7 @@ func initUsersTestData() *handler {
LastLogin: time.Time{}, LastLogin: time.Time{},
Issued: "api", Issued: "api",
}, },
Permissions: mergeRolePermissions(roles.User), Permissions: mergeRolePermissions(roles2.User),
Restricted: true, Restricted: true,
}, nil }, nil
} }
@@ -232,7 +237,11 @@ func TestGetUsers(t *testing.T) {
AccountId: existingAccountID, AccountId: existingAccountID,
}) })
userHandler.getAllUsers(recorder, req) userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.getAllUsers(recorder, req, userAuth)
res := recorder.Result() res := recorder.Result()
defer res.Body.Close() defer res.Body.Close()
@@ -343,7 +352,7 @@ func TestUpdateUser(t *testing.T) {
}) })
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT") router.HandleFunc("/api/users/{userId}", permissions.WrapHandler(userHandler.updateUser)).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@@ -439,7 +448,11 @@ func TestCreateUser(t *testing.T) {
AccountId: existingAccountID, AccountId: existingAccountID,
}) })
userHandler.createUser(rr, req) userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.createUser(rr, req, userAuth)
res := rr.Result() res := rr.Result()
defer res.Body.Close() defer res.Body.Close()
@@ -490,7 +503,11 @@ func TestInviteUser(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
userHandler.inviteUser(rr, req) userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.inviteUser(rr, req, userAuth)
res := rr.Result() res := rr.Result()
defer res.Body.Close() defer res.Body.Close()
@@ -549,7 +566,11 @@ func TestDeleteUser(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
userHandler.deleteUser(rr, req) userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.deleteUser(rr, req, userAuth)
res := rr.Result() res := rr.Result()
defer res.Body.Close() defer res.Body.Close()
@@ -608,7 +629,7 @@ func TestCurrentUser(t *testing.T) {
Issued: ptr("api"), Issued: ptr("api"),
LastLogin: ptr(time.Time{}), LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{ Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Owner)), Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Owner)),
}, },
}, },
}, },
@@ -627,7 +648,7 @@ func TestCurrentUser(t *testing.T) {
Issued: ptr("api"), Issued: ptr("api"),
LastLogin: ptr(time.Time{}), LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{ Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)), Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)),
}, },
}, },
}, },
@@ -646,7 +667,7 @@ func TestCurrentUser(t *testing.T) {
Issued: ptr("api"), Issued: ptr("api"),
LastLogin: ptr(time.Time{}), LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{ Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Admin)), Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Admin)),
}, },
}, },
}, },
@@ -666,7 +687,7 @@ func TestCurrentUser(t *testing.T) {
LastLogin: ptr(time.Time{}), LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{ Permissions: &api.UserPermissions{
IsRestricted: true, IsRestricted: true,
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)), Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)),
}, },
}, },
}, },
@@ -682,7 +703,11 @@ func TestCurrentUser(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
userHandler.getCurrentUser(rr, req) userAuth := &auth.UserAuth{
UserId: tc.requestAuth.UserId,
AccountId: existingAccountID,
}
userHandler.getCurrentUser(rr, req, userAuth)
res := rr.Result() res := rr.Result()
defer res.Body.Close() defer res.Body.Close()
@@ -702,8 +727,8 @@ func ptr[T any, PT *T](x T) PT {
return &x return &x
} }
func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { func mergeRolePermissions(role roles2.RolePermissions) roles2.Permissions {
permissions := roles.Permissions{} permissions := roles2.Permissions{}
for k := range modules.All { for k := range modules.All {
if rolePermissions, ok := role.Permissions[k]; ok { if rolePermissions, ok := role.Permissions[k]; ok {
@@ -716,7 +741,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
return permissions return permissions
} }
func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[string]bool { func stringifyPermissionsKeys(permissions roles2.Permissions) map[string]map[string]bool {
modules := make(map[string]map[string]bool) modules := make(map[string]map[string]bool)
for module, operations := range permissions { for module, operations := range permissions {
modules[string(module)] = make(map[string]bool) modules[string(module)] = make(map[string]bool)
@@ -779,7 +804,7 @@ func TestApproveUserEndpoint(t *testing.T) {
handler := newHandler(am) handler := newHandler(am)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST") router.HandleFunc("/users/{userId}/approve", permissions.WrapHandler(handler.approveUser)).Methods("POST")
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
require.NoError(t, err) require.NoError(t, err)
@@ -837,7 +862,7 @@ func TestRejectUserEndpoint(t *testing.T) {
handler := newHandler(am) handler := newHandler(am)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE") router.HandleFunc("/users/{userId}/reject", permissions.WrapHandler(handler.rejectUser)).Methods("DELETE")
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil) req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
require.NoError(t, err) require.NoError(t, err)
@@ -928,7 +953,7 @@ func TestChangePasswordEndpoint(t *testing.T) {
handler := newHandler(am) handler := newHandler(am)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/users/{userId}/password", handler.changePassword).Methods("PUT") router.HandleFunc("/users/{userId}/password", permissions.WrapHandler(handler.changePassword)).Methods("PUT")
reqPath := "/users/" + tc.targetUserID + "/password" reqPath := "/users/" + tc.targetUserID + "/password"
req, err := http.NewRequest("PUT", reqPath, bytes.NewBufferString(tc.requestBody)) req, err := http.NewRequest("PUT", reqPath, bytes.NewBufferString(tc.requestBody))
@@ -967,7 +992,7 @@ func TestChangePasswordEndpoint_WrongMethod(t *testing.T) {
req = nbcontext.SetUserAuthInRequest(req, userAuth) req = nbcontext.SetUserAuthInRequest(req, userAuth)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
handler.changePassword(rr, req) handler.changePassword(rr, req, &userAuth)
assert.Equal(t, http.StatusMethodNotAllowed, rr.Code) assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
} }

View File

@@ -27,7 +27,7 @@ func Test_Accounts_GetAll(t *testing.T) {
{"Regular user", testing_tools.TestUserId, false}, {"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true}, {"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true}, {"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, true}, {"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true}, {"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false}, {"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false}, {"Other user", testing_tools.OtherUserId, false},
@@ -233,6 +233,71 @@ func Test_Accounts_Update(t *testing.T) {
} }
} }
func Test_Accounts_Update_CrossAccountAttack(t *testing.T) {
t.Run("Other user attempts to update testAccount via URL", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(&api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: false,
PeerLoginExpiration: 86400,
},
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
// OtherUserId belongs to otherAccountId, but we target testAccountId in URL
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account update must be rejected")
})
}
func Test_Accounts_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, false},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, false},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Delete account", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
})
}
}
func Test_Accounts_Delete_CrossAccountAttack(t *testing.T) {
t.Run("Other user attempts to delete testAccount via URL", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// OtherUserId belongs to otherAccountId, but we target testAccountId in URL
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account delete must be rejected")
})
}
func stringPointer(s string) *string { func stringPointer(s string) *string {
return &s return &s
} }

View File

@@ -0,0 +1,445 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Records_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all records", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones/testZoneId/records", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.DNSRecord{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "sub.example.com", got[0].Name)
assert.Equal(t, api.DNSRecordTypeA, got[0].Type)
assert.Equal(t, "1.2.3.4", got[0].Content)
assert.Equal(t, 300, got[0].Ttl)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Records_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
recordId string
expectedStatus int
expectRecord bool
}{
{
name: "Get existing record",
zoneId: "testZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusOK,
expectRecord: true,
},
{
name: "Get non-existing record",
zoneId: "testZoneId",
recordId: "nonExistingRecordId",
expectedStatus: http.StatusNotFound,
expectRecord: false,
},
{
name: "Get record from non-existing zone",
zoneId: "nonExistingZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusNotFound,
expectRecord: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectRecord {
got := &api.DNSRecord{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testRecordId", got.Id)
assert.Equal(t, "sub.example.com", got.Name)
assert.Equal(t, api.DNSRecordTypeA, got.Type)
assert.Equal(t, "1.2.3.4", got.Content)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Records_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
requestBody *api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, record *api.DNSRecord)
}{
{
name: "Create A record",
zoneId: "testZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "new.example.com",
Type: api.DNSRecordTypeA,
Content: "5.6.7.8",
Ttl: 600,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
t.Helper()
assert.NotEmpty(t, record.Id)
assert.Equal(t, "new.example.com", record.Name)
assert.Equal(t, api.DNSRecordTypeA, record.Type)
assert.Equal(t, "5.6.7.8", record.Content)
assert.Equal(t, 600, record.Ttl)
},
},
{
name: "Create CNAME record",
zoneId: "testZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "alias.example.com",
Type: api.DNSRecordTypeCNAME,
Content: "target.example.com",
Ttl: 300,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
t.Helper()
assert.NotEmpty(t, record.Id)
assert.Equal(t, "alias.example.com", record.Name)
assert.Equal(t, api.DNSRecordTypeCNAME, record.Type)
assert.Equal(t, "target.example.com", record.Content)
},
},
{
name: "Create record with invalid content for A type",
zoneId: "testZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "bad.example.com",
Type: api.DNSRecordTypeA,
Content: "not-an-ip",
Ttl: 300,
},
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Create record in non-existing zone",
zoneId: "nonExistingZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "new.example.com",
Type: api.DNSRecordTypeA,
Content: "5.6.7.8",
Ttl: 600,
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
path := strings.Replace("/api/dns/zones/{zoneId}/records", "{zoneId}", tc.zoneId, 1)
req := testing_tools.BuildRequest(t, body, http.MethodPost, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.DNSRecord{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the created record directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbRecord := testing_tools.VerifyRecordInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbRecord.Name)
assert.Equal(t, got.Content, dbRecord.Content)
assert.Equal(t, got.Ttl, dbRecord.TTL)
}
})
}
}
}
func Test_Records_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
recordId string
requestBody *api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, record *api.DNSRecord)
}{
{
name: "Update record content and TTL",
zoneId: "testZoneId",
recordId: "testRecordId",
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "sub.example.com",
Type: api.DNSRecordTypeA,
Content: "10.20.30.40",
Ttl: 600,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
t.Helper()
assert.Equal(t, "sub.example.com", record.Name)
assert.Equal(t, "10.20.30.40", record.Content)
assert.Equal(t, 600, record.Ttl)
},
},
{
name: "Update non-existing record",
zoneId: "testZoneId",
recordId: "nonExistingRecordId",
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "sub.example.com",
Type: api.DNSRecordTypeA,
Content: "10.20.30.40",
Ttl: 600,
},
expectedStatus: http.StatusNotFound,
},
{
name: "Update record in non-existing zone",
zoneId: "nonExistingZoneId",
recordId: "testRecordId",
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "sub.example.com",
Type: api.DNSRecordTypeA,
Content: "10.20.30.40",
Ttl: 600,
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.DNSRecord{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the updated record directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbRecord := testing_tools.VerifyRecordInDB(t, db, tc.recordId)
assert.Equal(t, "10.20.30.40", dbRecord.Content)
assert.Equal(t, 600, dbRecord.TTL)
}
})
}
}
}
func Test_Records_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
recordId string
expectedStatus int
}{
{
name: "Delete existing record",
zoneId: "testZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing record",
zoneId: "testZoneId",
recordId: "nonExistingRecordId",
expectedStatus: http.StatusNotFound,
},
{
name: "Delete record from non-existing zone",
zoneId: "nonExistingZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify deletion in DB for successful deletes by privileged users
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyRecordNotInDB(t, db, tc.recordId)
}
})
}
}
}

View File

@@ -0,0 +1,416 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Zones_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all zones", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Zone{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "Test Zone", got[0].Name)
assert.Equal(t, "example.com", got[0].Domain)
assert.Equal(t, true, got[0].Enabled)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Zones_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
expectedStatus int
expectZone bool
}{
{
name: "Get existing zone",
zoneId: "testZoneId",
expectedStatus: http.StatusOK,
expectZone: true,
},
{
name: "Get non-existing zone",
zoneId: "nonExistingZoneId",
expectedStatus: http.StatusNotFound,
expectZone: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectZone {
got := &api.Zone{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testZoneId", got.Id)
assert.Equal(t, "Test Zone", got.Name)
assert.Equal(t, "example.com", got.Domain)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Zones_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
enabled := true
disabled := false
tt := []struct {
name string
requestBody *api.PostApiDnsZonesJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, zone *api.Zone)
}{
{
name: "Create zone with valid data",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "New Zone",
Domain: "newzone.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.NotEmpty(t, zone.Id)
assert.Equal(t, "New Zone", zone.Name)
assert.Equal(t, "newzone.com", zone.Domain)
assert.Equal(t, true, zone.Enabled)
assert.Equal(t, false, zone.EnableSearchDomain)
assert.Equal(t, 1, len(zone.DistributionGroups))
},
},
{
name: "Create zone with search domain enabled",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "Search Zone",
Domain: "search.example.com",
Enabled: &enabled,
EnableSearchDomain: true,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.NotEmpty(t, zone.Id)
assert.Equal(t, "Search Zone", zone.Name)
assert.Equal(t, true, zone.EnableSearchDomain)
},
},
{
name: "Create disabled zone",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "Disabled Zone",
Domain: "disabled.example.com",
Enabled: &disabled,
EnableSearchDomain: false,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.NotEmpty(t, zone.Id)
assert.Equal(t, false, zone.Enabled)
},
},
{
name: "Create zone with empty distribution groups",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "No Groups Zone",
Domain: "nogroups.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{},
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/zones", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Zone{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the created zone directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbZone := testing_tools.VerifyZoneInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbZone.Name)
assert.Equal(t, got.Domain, dbZone.Domain)
assert.Equal(t, got.Enabled, dbZone.Enabled)
assert.Equal(t, got.EnableSearchDomain, dbZone.EnableSearchDomain)
}
})
}
}
}
func Test_Zones_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
enabled := true
tt := []struct {
name string
zoneId string
requestBody *api.PutApiDnsZonesZoneIdJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, zone *api.Zone)
}{
{
name: "Update zone name and settings",
zoneId: "testZoneId",
requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{
Name: "Updated Zone",
Domain: "example.com",
Enabled: &enabled,
EnableSearchDomain: true,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.Equal(t, "Updated Zone", zone.Name)
assert.Equal(t, "example.com", zone.Domain)
assert.Equal(t, true, zone.EnableSearchDomain)
},
},
{
name: "Update non-existing zone",
zoneId: "nonExistingZoneId",
requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{
Name: "Whatever",
Domain: "whatever.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Zone{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the updated zone directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbZone := testing_tools.VerifyZoneInDB(t, db, tc.zoneId)
assert.Equal(t, "Updated Zone", dbZone.Name)
assert.Equal(t, "example.com", dbZone.Domain)
assert.Equal(t, true, dbZone.Enabled)
assert.Equal(t, true, dbZone.EnableSearchDomain)
}
})
}
}
}
func Test_Zones_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
expectedStatus int
}{
{
name: "Delete existing zone",
zoneId: "testZoneId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing zone",
zoneId: "nonExistingZoneId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify deletion in DB for successful deletes by privileged users
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyZoneNotInDB(t, db, tc.zoneId)
}
})
}
}
}

View File

@@ -78,6 +78,68 @@ func Test_Events_GetAll(t *testing.T) {
} }
} }
func Test_Events_GetAll_Audit(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all audit events", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, false)
// First, perform a mutation to generate an event (create a group as admin)
groupBody, err := json.Marshal(&api.GroupRequest{Name: "auditTestGroup"})
if err != nil {
t.Fatalf("Failed to marshal group request: %v", err)
}
createReq := testing_tools.BuildRequest(t, groupBody, http.MethodPost, "/api/groups", testing_tools.TestAdminId)
createRecorder := httptest.NewRecorder()
apiHandler.ServeHTTP(createRecorder, createReq)
assert.Equal(t, http.StatusOK, createRecorder.Code, "Failed to create group to generate event")
// Now query audit events
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events/audit", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Event{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.GreaterOrEqual(t, len(got), 1, "Expected at least one event after creating a group")
// Verify the group creation event exists
found := false
for _, event := range got {
if event.ActivityCode == "group.add" {
found = true
assert.Equal(t, testing_tools.TestAdminId, event.InitiatorId)
assert.Equal(t, "Group created", event.Activity)
break
}
}
assert.True(t, found, "Expected to find a group.add event")
})
}
}
func Test_Events_GetAll_Empty(t *testing.T) { func Test_Events_GetAll_Empty(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, true) apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, true)

View File

@@ -0,0 +1,118 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Geolocations_GetAllCountries(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all countries", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/locations/countries", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Country{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got))
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Geolocations_GetCitiesByCountry(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get cities by country", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "US", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.City{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got))
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Geolocations_GetCitiesByCountry_InvalidCode(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "INVALID", 1), testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, http.StatusUnprocessableEntity, true)
}

View File

@@ -26,7 +26,7 @@ func Test_Groups_GetAll(t *testing.T) {
{"Regular user", testing_tools.TestUserId, false}, {"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true}, {"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true}, {"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, true}, {"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true}, {"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false}, {"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false}, {"Other user", testing_tools.OtherUserId, false},
@@ -71,7 +71,7 @@ func Test_Groups_GetById(t *testing.T) {
{"Regular user", testing_tools.TestUserId, false}, {"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true}, {"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true}, {"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, true}, {"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true}, {"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false}, {"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false}, {"Other user", testing_tools.OtherUserId, false},
@@ -216,7 +216,6 @@ func Test_Groups_Create(t *testing.T) {
} }
tc.verifyResponse(t, got) tc.verifyResponse(t, got)
// Verify group exists in DB
db := testing_tools.GetDB(t, am.GetStore()) db := testing_tools.GetDB(t, am.GetStore())
dbGroup := testing_tools.VerifyGroupInDB(t, db, got.Id) dbGroup := testing_tools.VerifyGroupInDB(t, db, got.Id)
assert.Equal(t, tc.requestBody.Name, dbGroup.Name) assert.Equal(t, tc.requestBody.Name, dbGroup.Name)

View File

@@ -0,0 +1,295 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_IdentityProviders_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all identity providers", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/identity-providers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.IdentityProvider{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
// The embedded IdP manager is not initialized in the test environment,
// so GetIdentityProviders returns an empty list.
assert.Equal(t, 0, len(got))
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_IdentityProviders_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
idpId string
expectedStatus int
}{
{
name: "Get existing identity provider",
idpId: "testIdpId",
expectedStatus: http.StatusInternalServerError,
},
{
name: "Get non-existing identity provider",
idpId: "nonExistingIdpId",
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_IdentityProviders_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
requestBody *api.PostApiIdentityProvidersJSONRequestBody
expectedStatus int
}{
{
name: "Create identity provider with valid data",
requestBody: &api.PostApiIdentityProvidersJSONRequestBody{
Type: api.IdentityProviderTypeGoogle,
Name: "New IDP",
ClientId: "newClientId",
ClientSecret: "newClientSecret",
},
// Validation passes but the embedded IdP manager is not initialized,
// so the operation returns an internal server error.
expectedStatus: http.StatusInternalServerError,
},
{
name: "Create identity provider with invalid issuer",
requestBody: &api.PostApiIdentityProvidersJSONRequestBody{
Type: api.IdentityProviderTypeOidc,
Name: "Invalid IDP",
Issuer: "not-a-url",
ClientId: "clientId",
ClientSecret: "clientSecret",
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/identity-providers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_IdentityProviders_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
idpId string
requestBody *api.PutApiIdentityProvidersIdpIdJSONRequestBody
expectedStatus int
}{
{
name: "Update existing identity provider",
idpId: "testIdpId",
requestBody: &api.PutApiIdentityProvidersIdpIdJSONRequestBody{
Type: api.IdentityProviderTypeGoogle,
Name: "Updated IDP",
ClientId: "updatedClientId",
ClientSecret: "updatedClientSecret",
},
// Validation passes but the embedded IdP manager is not initialized,
// so the operation returns an internal server error.
expectedStatus: http.StatusInternalServerError,
},
{
name: "Update non-existing identity provider",
idpId: "nonExistingIdpId",
requestBody: &api.PutApiIdentityProvidersIdpIdJSONRequestBody{
Type: api.IdentityProviderTypeGoogle,
Name: "Updated IDP",
ClientId: "updatedClientId",
ClientSecret: "updatedClientSecret",
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_IdentityProviders_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
idpId string
expectedStatus int
}{
{
name: "Delete existing identity provider",
idpId: "testIdpId",
expectedStatus: http.StatusInternalServerError,
},
{
name: "Delete non-existing identity provider",
idpId: "nonExistingIdpId",
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}

View File

@@ -0,0 +1,183 @@
//go:build integration
package integration
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// Test_Instance_GetStatus tests the unauthenticated GET /api/instance endpoint.
// This endpoint bypasses auth middleware. With nil idpManager (no embedded IDP),
// SetupRequired should be false.
func Test_Instance_GetStatus(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// The /api/instance endpoint is unauthenticated (bypass path).
// We still pass a token via BuildRequest but the bypass middleware skips auth.
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
assert.Equal(t, http.StatusOK, recorder.Code, "Expected 200 OK for instance status endpoint, got %d: %s", recorder.Code, string(content))
got := &api.InstanceStatus{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
// With nil idpManager (no embedded IDP configured), setup is not required.
assert.Equal(t, false, got.SetupRequired, "Expected SetupRequired to be false when embedded IDP is not configured")
}
// Test_Instance_GetStatus_Unauthenticated verifies the endpoint works without any
// valid user token, since it is on the bypass path.
func Test_Instance_GetStatus_Unauthenticated(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// Use an invalid token to confirm the bypass middleware skips auth
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance", testing_tools.InvalidToken)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
assert.Equal(t, http.StatusOK, recorder.Code, "Expected 200 OK for unauthenticated instance status endpoint, got %d: %s", recorder.Code, string(content))
got := &api.InstanceStatus{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, false, got.SetupRequired)
}
// Test_Instance_GetVersionInfo tests the authenticated GET /api/instance/version endpoint.
func Test_Instance_GetVersionInfo(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get version info", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance/version", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.InstanceVersionInfo{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.NotEmpty(t, got.ManagementCurrentVersion, "Expected non-empty current version")
})
}
}
// Test_Instance_Setup tests the unauthenticated POST /api/setup endpoint.
// Since embedded IDP is not configured in the test environment, the setup
// endpoint should return an error (500 Internal Server Error) because the
// instance manager's CreateOwnerUser returns "embedded IDP is not enabled".
func Test_Instance_Setup(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(&api.SetupRequest{
Email: "admin@test.com",
Password: "securepassword123",
Name: "Admin User",
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
// The /api/setup endpoint is unauthenticated (bypass path).
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/setup", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
// Without embedded IDP, CreateOwnerUser returns a plain error (not a status error),
// which the handler maps to 500 Internal Server Error.
assert.Equal(t, http.StatusInternalServerError, recorder.Code,
"Expected 500 when embedded IDP is not configured, got %d: %s", recorder.Code, string(content))
}
// Test_Instance_Setup_Unauthenticated verifies the setup endpoint works (reaches
// the handler) even with an invalid token, since it is on the bypass path.
func Test_Instance_Setup_Unauthenticated(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(&api.SetupRequest{
Email: "admin@test.com",
Password: "securepassword123",
Name: "Admin User",
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/setup", testing_tools.InvalidToken)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
// Even with an invalid token, the bypass middleware skips auth and the handler runs.
// Without embedded IDP, it returns 500.
assert.Equal(t, http.StatusInternalServerError, recorder.Code,
"Expected 500 when embedded IDP is not configured, got %d: %s", recorder.Code, string(content))
}

View File

@@ -0,0 +1,154 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// Note: The integration test infrastructure does not configure an embedded IDP,
// so actual invite operations will return PreconditionFailed (412) for authorized users.
// These tests verify that the permissions layer correctly denies regular users
// before the handler logic is reached.
func Test_Invites_List(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - List invites", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users/invites", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}
func Test_Invites_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Create invite", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
body, err := json.Marshal(&api.UserInviteCreateRequest{
Email: "newuser@test.com",
Name: "New User",
Role: "user",
AutoGroups: []string{},
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/users/invites", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}
func Test_Invites_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Delete invite", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/invites/{inviteId}", "{inviteId}", "someInviteId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}
func Test_Invites_Regenerate(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Regenerate invite", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodPost, strings.Replace("/api/users/invites/{inviteId}/regenerate", "{inviteId}", "someInviteId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}

View File

@@ -45,7 +45,7 @@ func Test_Peers_GetAll(t *testing.T) {
{ {
name: "Regular service user", name: "Regular service user",
userId: testing_tools.TestServiceUserId, userId: testing_tools.TestServiceUserId,
expectResponse: true, expectResponse: false,
}, },
{ {
name: "Admin service user", name: "Admin service user",
@@ -123,7 +123,7 @@ func Test_Peers_GetById(t *testing.T) {
{ {
name: "Regular service user", name: "Regular service user",
userId: testing_tools.TestServiceUserId, userId: testing_tools.TestServiceUserId,
expectResponse: true, expectResponse: false,
}, },
{ {
name: "Admin service user", name: "Admin service user",

View File

@@ -0,0 +1,372 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_PostureChecks_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all posture checks", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/posture-checks", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []*api.PostureCheck{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "testPostureCheckId", got[0].Id)
assert.Equal(t, "NetBird Version Check", got[0].Name)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_PostureChecks_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
postureCheckId string
expectedStatus int
expectCheck bool
}{
{
name: "Get existing posture check",
postureCheckId: "testPostureCheckId",
expectedStatus: http.StatusOK,
expectCheck: true,
},
{
name: "Get non-existing posture check",
postureCheckId: "nonExistingPostureCheckId",
expectedStatus: http.StatusNotFound,
expectCheck: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectCheck {
got := &api.PostureCheck{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testPostureCheckId", got.Id)
assert.Equal(t, "NetBird Version Check", got.Name)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_PostureChecks_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
minVersion := "0.32.0"
tt := []struct {
name string
requestBody *api.PostureCheckUpdate
expectedStatus int
verifyResponse func(t *testing.T, check *api.PostureCheck)
}{
{
name: "Create posture check with NB version",
requestBody: &api.PostureCheckUpdate{
Name: "New Version Check",
Description: "check for new version",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: minVersion,
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, check *api.PostureCheck) {
t.Helper()
assert.NotEmpty(t, check.Id)
assert.Equal(t, "New Version Check", check.Name)
assert.NotNil(t, check.Checks.NbVersionCheck)
assert.Equal(t, minVersion, check.Checks.NbVersionCheck.MinVersion)
},
},
{
name: "Create posture check with empty name",
requestBody: &api.PostureCheckUpdate{
Name: "",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: "0.32.0",
},
},
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/posture-checks", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.PostureCheck{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
db := testing_tools.GetDB(t, am.GetStore())
dbCheck := testing_tools.VerifyPostureCheckInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbCheck.Name)
}
})
}
}
}
func Test_PostureChecks_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
minVersion := "0.33.0"
tt := []struct {
name string
postureCheckId string
requestBody *api.PostureCheckUpdate
expectedStatus int
verifyResponse func(t *testing.T, check *api.PostureCheck)
}{
{
name: "Update posture check name and version",
postureCheckId: "testPostureCheckId",
requestBody: &api.PostureCheckUpdate{
Name: "Updated Version Check",
Description: "updated description",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: minVersion,
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, check *api.PostureCheck) {
t.Helper()
assert.Equal(t, "testPostureCheckId", check.Id)
assert.Equal(t, "Updated Version Check", check.Name)
},
},
{
name: "Update non-existing posture check",
postureCheckId: "nonExistingPostureCheckId",
requestBody: &api.PostureCheckUpdate{
Name: "whatever",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: "0.33.0",
},
},
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.PostureCheck{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
db := testing_tools.GetDB(t, am.GetStore())
dbCheck := testing_tools.VerifyPostureCheckInDB(t, db, tc.postureCheckId)
assert.Equal(t, "Updated Version Check", dbCheck.Name)
}
})
}
}
}
func Test_PostureChecks_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
postureCheckId string
expectedStatus int
}{
{
name: "Delete existing posture check",
postureCheckId: "testPostureCheckId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing posture check",
postureCheckId: "nonExistingPostureCheckId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyPostureCheckNotInDB(t, db, tc.postureCheckId)
}
})
}
}
}

View File

@@ -0,0 +1,306 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_ReverseProxy_GetClusters(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get clusters", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/clusters", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.ProxyCluster{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got), "Expected empty clusters list when no proxy is connected")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_ReverseProxy_GetAllServices(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all services", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/services", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []*api.Service{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got), "Expected empty services list with no services in DB")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_ReverseProxy_CreateService(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Create service", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// Creating a service requires a valid domain entry. Without one, authorized
// users will receive a validation error. This test verifies that the
// permissions layer correctly rejects unauthorized users before handler
// logic runs, and that authorized users reach the handler (even if the
// handler returns an error due to missing domain).
body, err := json.Marshal(&api.ServiceRequest{
Name: "test-service",
Domain: "nonexistent.example.com",
Enabled: true,
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/reverse-proxies/services", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users reach the handler but get a validation error (no valid domain).
// Unauthorized users are rejected by the permissions middleware.
testing_tools.ReadResponse(t, recorder, http.StatusUnprocessableEntity, user.expectResponse)
})
}
}
func Test_ReverseProxy_GetServiceById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get service by ID (non-existing)", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/reverse-proxies/services/{serviceId}", "{serviceId}", "nonExistingServiceId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get NotFound for a non-existing service.
// Unauthorized users are rejected by the permissions middleware.
testing_tools.ReadResponse(t, recorder, http.StatusNotFound, user.expectResponse)
})
}
}
func Test_ReverseProxy_DeleteService(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Delete service (non-existing)", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/reverse-proxies/services/{serviceId}", "{serviceId}", "nonExistingServiceId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get NotFound for a non-existing service.
// Unauthorized users are rejected by the permissions middleware.
testing_tools.ReadResponse(t, recorder, http.StatusNotFound, user.expectResponse)
})
}
}
func Test_ReverseProxy_GetAllDomains(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all domains", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/domains", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.ReverseProxyDomain{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got), "Expected empty domains list with no domains in DB")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_ReverseProxy_GetAccessLogs(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get proxy access logs", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events/proxy", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.ProxyAccessLogsResponse{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got.Data), "Expected empty access logs data")
assert.Equal(t, 0, got.TotalRecords, "Expected zero total records")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}

View File

@@ -0,0 +1,124 @@
//go:build integration
package integration
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
)
func Test_Users_Approve(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
targetUserId string
expectedStatus int
}{
{
name: "Approve pending user",
targetUserId: "pendingUserId",
expectedStatus: http.StatusOK,
},
{
name: "Approve non-existing user",
targetUserId: "nonExistingUserId",
expectedStatus: http.StatusNotFound,
},
{
name: "Approve already active user",
targetUserId: testing_tools.TestUserId,
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_approve_reject.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodPost, strings.Replace("/api/users/{userId}/approve", "{userId}", tc.targetUserId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_Users_Reject(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
targetUserId string
expectedStatus int
}{
{
name: "Reject pending user",
targetUserId: "pendingUserId",
expectedStatus: http.StatusOK,
},
{
name: "Reject non-existing user",
targetUserId: "nonExistingUserId",
expectedStatus: http.StatusNotFound,
},
{
name: "Reject non-pending user",
targetUserId: testing_tools.TestUserId,
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_approve_reject.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/{userId}/reject", "{userId}", tc.targetUserId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
_, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if expectResponse && tc.expectedStatus == http.StatusOK {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyUserNotInDB(t, db, tc.targetUserId)
}
})
}
}
}

View File

@@ -0,0 +1,64 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Users_GetCurrent(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, true},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, false},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get current user", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users/current", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.User{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, user.userId, got.Id)
assert.NotNil(t, got.IsCurrent)
assert.True(t, *got.IsCurrent)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}

View File

@@ -26,7 +26,7 @@ func Test_Users_GetAll(t *testing.T) {
{"Regular user", testing_tools.TestUserId, true}, {"Regular user", testing_tools.TestUserId, true},
{"Admin user", testing_tools.TestAdminId, true}, {"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true}, {"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, true}, {"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true}, {"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false}, {"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false}, {"Other user", testing_tools.OtherUserId, false},
@@ -637,6 +637,38 @@ func Test_PATs_Create(t *testing.T) {
} }
} }
func Test_Users_Update_CrossAccountAttack(t *testing.T) {
t.Run("Admin attempts to update user from other account", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
body, _ := json.Marshal(&api.UserRequest{
Role: "user",
AutoGroups: []string{},
IsBlocked: true,
})
// TestAdminId belongs to testAccountId, but targets otherUserId which belongs to otherAccountId
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/users/otherUserId", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account user update must be rejected")
})
}
func Test_Users_Delete_CrossAccountAttack(t *testing.T) {
t.Run("Admin attempts to delete service user from other account", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
// TestAdminId belongs to testAccountId, but targets otherServiceUserId which belongs to otherAccountId
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/users/otherServiceUserId", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account user delete must be rejected")
})
}
func Test_PATs_Delete(t *testing.T) { func Test_PATs_Delete(t *testing.T) {
users := []struct { users := []struct {
name string name string

View File

@@ -15,4 +15,4 @@ INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NU
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);

View File

@@ -15,7 +15,7 @@ INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,N
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO name_server_groups VALUES('testNSGroupId','testAccountId','testNSGroup','test nameserver group','[{"IP":"1.1.1.1","NSType":1,"Port":53}]','["testGroupId"]',0,'["example.com"]',1,0); INSERT INTO name_server_groups VALUES('testNSGroupId','testAccountId','testNSGroup','test nameserver group','[{"IP":"1.1.1.1","NSType":1,"Port":53}]','["testGroupId"]',0,'["example.com"]',1,0);

View File

@@ -0,0 +1,26 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `zones` (`id` text,`account_id` text,`name` text,`domain` text,`enabled` numeric,`enable_search_domain` numeric,`distribution_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_zones` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `records` (`id` text,`account_id` text,`zone_id` text,`name` text,`type` text,`content` text,`ttl` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_zones_records` FOREIGN KEY (`zone_id`) REFERENCES `zones`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO name_server_groups VALUES('testNSGroupId','testAccountId','testNSGroup','test nameserver group','[{"IP":"1.1.1.1","NSType":1,"Port":53}]','["testGroupId"]',0,'["example.com"]',1,0);
INSERT INTO zones VALUES('testZoneId','testAccountId','Test Zone','example.com',1,0,'["testGroupId"]');
INSERT INTO records VALUES('testRecordId','testAccountId','testZoneId','sub.example.com','A','1.2.3.4',300);

View File

@@ -14,5 +14,5 @@ INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,N
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);

View File

@@ -15,5 +15,5 @@ INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NU
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('allGroupId','testAccountId','All','api','[]',0,''); INSERT INTO "groups" VALUES('allGroupId','testAccountId','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);

View File

@@ -0,0 +1,20 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `identity_providers` (`id` text,`account_id` text,`type` text,`name` text,`issuer` text,`client_id` text,`client_secret` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_identity_providers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO identity_providers VALUES('testIdpId','testAccountId','oidc','Test IDP','https://issuer.example.com','client123','secret456');

Some files were not shown because too many files have changed in this diff Show More