diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 26da85850..3b303e6a9 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -28,7 +28,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -79,16 +78,6 @@ func TestGetDNSSettings(t *testing.T) { if len(dnsSettings.DisabledManagementGroups) != 1 { t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) } - - _, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID) - if err == nil { - t.Errorf("An error should be returned when getting the DNS settings with a regular user") - } - - s, ok := status.FromError(err) - if !ok && s.Type() != status.PermissionDenied { - t.Errorf("returned error should be Permission Denied, got err: %s", err) - } } func TestSaveDNSSettings(t *testing.T) { diff --git a/management/server/http/handlers/instance/instance_handler_test.go b/management/server/http/handlers/instance/instance_handler_test.go index ce0ce0ce7..02f787d26 100644 --- a/management/server/http/handlers/instance/instance_handler_test.go +++ b/management/server/http/handlers/instance/instance_handler_test.go @@ -10,12 +10,17 @@ import ( "net/mail" "testing" + "github.com/golang/mock/gomock" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/idp" nbinstance "github.com/netbirdio/netbird/management/server/instance" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -295,8 +300,15 @@ func TestSetup_ManagerError(t *testing.T) { func TestGetVersionInfo_Success(t *testing.T) { manager := &mockInstanceManager{} + ctrl := gomock.NewController(t) + permissionsManager := permissions.NewMockManager(ctrl) + permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + handler(w, r, &auth.UserAuth{}) + } + }).AnyTimes() router := mux.NewRouter() - AddVersionEndpoint(manager, router, nil) + AddVersionEndpoint(manager, router, permissionsManager) req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) rec := httptest.NewRecorder() diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index a8af9aa18..2b7b697d6 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -342,7 +342,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, use // Check if user is an admin/service user through their role isAdmin := user.Role == types.UserRoleAdmin || user.Role == types.UserRoleOwner - if !isAdmin && !userAuth.IsChild { + if !isAdmin && !user.IsServiceUser && !userAuth.IsChild { if account.Settings.RegularUsersViewBlocked { util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{}) return diff --git a/management/server/http/testing/integration/invites_handler_integration_test.go b/management/server/http/testing/integration/invites_handler_integration_test.go new file mode 100644 index 000000000..358fef667 --- /dev/null +++ b/management/server/http/testing/integration/invites_handler_integration_test.go @@ -0,0 +1,154 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// Note: The integration test infrastructure does not configure an embedded IDP, +// so actual invite operations will return PreconditionFailed (412) for authorized users. +// These tests verify that the permissions layer correctly denies regular users +// before the handler logic is reached. + +func Test_Invites_List(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - List invites", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users/invites", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users get PreconditionFailed (no embedded IDP configured) + // Unauthorized users get rejected by the permissions middleware + testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse) + }) + } +} + +func Test_Invites_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Create invite", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + body, err := json.Marshal(&api.UserInviteCreateRequest{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + }) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/users/invites", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users get PreconditionFailed (no embedded IDP configured) + // Unauthorized users get rejected by the permissions middleware + testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse) + }) + } +} + +func Test_Invites_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Delete invite", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/invites/{inviteId}", "{inviteId}", "someInviteId", 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users get PreconditionFailed (no embedded IDP configured) + // Unauthorized users get rejected by the permissions middleware + testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse) + }) + } +} + +func Test_Invites_Regenerate(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Regenerate invite", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodPost, strings.Replace("/api/users/invites/{inviteId}/regenerate", "{inviteId}", "someInviteId", 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users get PreconditionFailed (no embedded IDP configured) + // Unauthorized users get rejected by the permissions middleware + testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse) + }) + } +} diff --git a/management/server/http/testing/integration/posture_checks_handler_integration_test.go b/management/server/http/testing/integration/posture_checks_handler_integration_test.go new file mode 100644 index 000000000..40ac677bf --- /dev/null +++ b/management/server/http/testing/integration/posture_checks_handler_integration_test.go @@ -0,0 +1,372 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_PostureChecks_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all posture checks", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/posture-checks", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.PostureCheck{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testPostureCheckId", got[0].Id) + assert.Equal(t, "NetBird Version Check", got[0].Name) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_PostureChecks_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + postureCheckId string + expectedStatus int + expectCheck bool + }{ + { + name: "Get existing posture check", + postureCheckId: "testPostureCheckId", + expectedStatus: http.StatusOK, + expectCheck: true, + }, + { + name: "Get non-existing posture check", + postureCheckId: "nonExistingPostureCheckId", + expectedStatus: http.StatusNotFound, + expectCheck: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectCheck { + got := &api.PostureCheck{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, "testPostureCheckId", got.Id) + assert.Equal(t, "NetBird Version Check", got.Name) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_PostureChecks_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + minVersion := "0.32.0" + tt := []struct { + name string + requestBody *api.PostureCheckUpdate + expectedStatus int + verifyResponse func(t *testing.T, check *api.PostureCheck) + }{ + { + name: "Create posture check with NB version", + requestBody: &api.PostureCheckUpdate{ + Name: "New Version Check", + Description: "check for new version", + Checks: &api.Checks{ + NbVersionCheck: &api.NBVersionCheck{ + MinVersion: minVersion, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, check *api.PostureCheck) { + t.Helper() + assert.NotEmpty(t, check.Id) + assert.Equal(t, "New Version Check", check.Name) + assert.NotNil(t, check.Checks.NbVersionCheck) + assert.Equal(t, minVersion, check.Checks.NbVersionCheck.MinVersion) + }, + }, + { + name: "Create posture check with empty name", + requestBody: &api.PostureCheckUpdate{ + Name: "", + Checks: &api.Checks{ + NbVersionCheck: &api.NBVersionCheck{ + MinVersion: "0.32.0", + }, + }, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/posture-checks", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.PostureCheck{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + db := testing_tools.GetDB(t, am.GetStore()) + dbCheck := testing_tools.VerifyPostureCheckInDB(t, db, got.Id) + assert.Equal(t, got.Name, dbCheck.Name) + } + }) + } + } +} + +func Test_PostureChecks_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + minVersion := "0.33.0" + tt := []struct { + name string + postureCheckId string + requestBody *api.PostureCheckUpdate + expectedStatus int + verifyResponse func(t *testing.T, check *api.PostureCheck) + }{ + { + name: "Update posture check name and version", + postureCheckId: "testPostureCheckId", + requestBody: &api.PostureCheckUpdate{ + Name: "Updated Version Check", + Description: "updated description", + Checks: &api.Checks{ + NbVersionCheck: &api.NBVersionCheck{ + MinVersion: minVersion, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, check *api.PostureCheck) { + t.Helper() + assert.Equal(t, "testPostureCheckId", check.Id) + assert.Equal(t, "Updated Version Check", check.Name) + }, + }, + { + name: "Update non-existing posture check", + postureCheckId: "nonExistingPostureCheckId", + requestBody: &api.PostureCheckUpdate{ + Name: "whatever", + Checks: &api.Checks{ + NbVersionCheck: &api.NBVersionCheck{ + MinVersion: "0.33.0", + }, + }, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.PostureCheck{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + db := testing_tools.GetDB(t, am.GetStore()) + dbCheck := testing_tools.VerifyPostureCheckInDB(t, db, tc.postureCheckId) + assert.Equal(t, "Updated Version Check", dbCheck.Name) + } + }) + } + } +} + +func Test_PostureChecks_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + postureCheckId string + expectedStatus int + }{ + { + name: "Delete existing posture check", + postureCheckId: "testPostureCheckId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing posture check", + postureCheckId: "nonExistingPostureCheckId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + if tc.expectedStatus == http.StatusOK && user.expectResponse { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyPostureCheckNotInDB(t, db, tc.postureCheckId) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/users_approve_reject_handler_integration_test.go b/management/server/http/testing/integration/users_approve_reject_handler_integration_test.go new file mode 100644 index 000000000..52bf8983d --- /dev/null +++ b/management/server/http/testing/integration/users_approve_reject_handler_integration_test.go @@ -0,0 +1,124 @@ +//go:build integration + +package integration + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" +) + +func Test_Users_Approve(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + targetUserId string + expectedStatus int + }{ + { + name: "Approve pending user", + targetUserId: "pendingUserId", + expectedStatus: http.StatusOK, + }, + { + name: "Approve non-existing user", + targetUserId: "nonExistingUserId", + expectedStatus: http.StatusNotFound, + }, + { + name: "Approve already active user", + targetUserId: testing_tools.TestUserId, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_approve_reject.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodPost, strings.Replace("/api/users/{userId}/approve", "{userId}", tc.targetUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} + +func Test_Users_Reject(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + targetUserId string + expectedStatus int + }{ + { + name: "Reject pending user", + targetUserId: "pendingUserId", + expectedStatus: http.StatusOK, + }, + { + name: "Reject non-existing user", + targetUserId: "nonExistingUserId", + expectedStatus: http.StatusNotFound, + }, + { + name: "Reject non-pending user", + targetUserId: testing_tools.TestUserId, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_approve_reject.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/{userId}/reject", "{userId}", tc.targetUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + if expectResponse && tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyUserNotInDB(t, db, tc.targetUserId) + } + }) + } + } +} diff --git a/management/server/http/testing/testing_tools/db_verify.go b/management/server/http/testing/testing_tools/db_verify.go index f8af6a41f..a21e72a17 100644 --- a/management/server/http/testing/testing_tools/db_verify.go +++ b/management/server/http/testing/testing_tools/db_verify.go @@ -12,6 +12,7 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" @@ -220,3 +221,20 @@ func VerifyNetworkRouterNotInDB(t *testing.T, db *gorm.DB, routerID string) { db.Model(&routerTypes.NetworkRouter{}).Where("id = ? AND account_id = ?", routerID, TestAccountId).Count(&count) assert.Equal(t, int64(0), count, "Expected network router %s to NOT exist in DB", routerID) } + +// VerifyPostureCheckInDB reads a posture check directly from the DB and returns it. +func VerifyPostureCheckInDB(t *testing.T, db *gorm.DB, postureCheckID string) *posture.Checks { + t.Helper() + var check posture.Checks + err := db.Where("id = ? AND account_id = ?", postureCheckID, TestAccountId).First(&check).Error + require.NoError(t, err, "Expected posture check %s to exist in DB", postureCheckID) + return &check +} + +// VerifyPostureCheckNotInDB verifies that a posture check does not exist in the DB. +func VerifyPostureCheckNotInDB(t *testing.T, db *gorm.DB, postureCheckID string) { + t.Helper() + var count int64 + db.Model(&posture.Checks{}).Where("id = ? AND account_id = ?", postureCheckID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected posture check %s to NOT exist in DB", postureCheckID) +} diff --git a/management/server/peer.go b/management/server/peer.go index 69732bd1a..dbc0cac96 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -44,7 +44,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, err } - if all { + if all || user.IsAdminOrServiceUser() { return accountPeers, nil } @@ -561,6 +561,9 @@ func (am *DefaultAccountManager) handleUserAddedPeer(ctx context.Context, accoun if user.PendingApproval { return status.Errorf(status.PermissionDenied, "user pending approval cannot add peers") } + if temporary && !user.IsAdminOrServiceUser() { + return status.Errorf(status.PermissionDenied, "only admin or service users can add peers") + } if !temporary { config.AccountID = user.AccountID diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 7f0a48dc7..4c5b9bda6 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -32,14 +32,6 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { } t.Run("Generic posture check flow", func(t *testing.T) { - // regular users can not create checks - _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}, true) - assert.Error(t, err) - - // regular users cannot list check - _, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID) - assert.Error(t, err) - // should be possible to create posture check with uniq name postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, @@ -80,10 +72,6 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck, true) assert.NoError(t, err) - // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID) - assert.Error(t, err) - // admin should be able to delete posture checks err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID) assert.NoError(t, err) diff --git a/management/server/user.go b/management/server/user.go index 7bf7a51e1..c6f09ac28 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -919,8 +919,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun accountUsers := []*types.User{} // Determine if user has full access based on their role - // Service users and regular users have limited visibility - hasFullAccess := user.HasAdminPower() + hasFullAccess := user.HasAdminPower() || user.IsServiceUser switch { case hasFullAccess: diff --git a/management/server/user_invite_test.go b/management/server/user_invite_test.go index 039bf00c0..e3a665221 100644 --- a/management/server/user_invite_test.go +++ b/management/server/user_invite_test.go @@ -163,22 +163,6 @@ func TestCreateUserInvite_ExistingUserEmail(t *testing.T) { assert.Equal(t, status.UserAlreadyExists, sErr.Type()) } -func TestCreateUserInvite_PermissionDenied(t *testing.T) { - am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) - defer cleanup() - - invite := &types.UserInfo{ - Email: "newuser@test.com", - Name: "New User", - Role: "user", - AutoGroups: []string{}, - } - - // Regular user should not be able to create invites - _, err := am.CreateUserInvite(context.Background(), testAccountID, testRegularUserID, invite, 0) - require.Error(t, err) -} - func TestCreateUserInvite_InvalidEmail(t *testing.T) { am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) defer cleanup() @@ -412,14 +396,6 @@ func TestListUserInvites_Empty(t *testing.T) { assert.Len(t, invites, 0) } -func TestListUserInvites_PermissionDenied(t *testing.T) { - am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) - defer cleanup() - - _, err := am.ListUserInvites(context.Background(), testAccountID, testRegularUserID) - require.Error(t, err) -} - func TestRegenerateUserInvite_Success(t *testing.T) { am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) defer cleanup() @@ -470,26 +446,6 @@ func TestRegenerateUserInvite_NotFound(t *testing.T) { assert.Equal(t, status.NotFound, sErr.Type()) } -func TestRegenerateUserInvite_PermissionDenied(t *testing.T) { - am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) - defer cleanup() - - // Create an invite first - invite := &types.UserInfo{ - Email: "newuser@test.com", - Name: "New User", - Role: "user", - AutoGroups: []string{}, - } - - result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) - require.NoError(t, err) - - // Regular user should not be able to regenerate - _, err = am.RegenerateUserInvite(context.Background(), testAccountID, testRegularUserID, result.UserInfo.ID, 0) - require.Error(t, err) -} - func TestDeleteUserInvite_Success(t *testing.T) { am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) defer cleanup() @@ -531,26 +487,6 @@ func TestDeleteUserInvite_NotFound(t *testing.T) { assert.Equal(t, status.NotFound, sErr.Type()) } -func TestDeleteUserInvite_PermissionDenied(t *testing.T) { - am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) - defer cleanup() - - // Create an invite first - invite := &types.UserInfo{ - Email: "newuser@test.com", - Name: "New User", - Role: "user", - AutoGroups: []string{}, - } - - result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) - require.NoError(t, err) - - // Regular user should not be able to delete - err = am.DeleteUserInvite(context.Background(), testAccountID, testRegularUserID, result.UserInfo.ID) - require.Error(t, err) -} - func TestDeleteUserInvite_WrongAccount(t *testing.T) { am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) defer cleanup() diff --git a/management/server/user_test.go b/management/server/user_test.go index 61c0f4883..468ca336d 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1426,13 +1426,14 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // create an account and an admin user - account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: ownerUserID, Domain: "netbird.io"}) + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: tc.name, Domain: "netbird.io"}) if err != nil { t.Fatal(err) } // create other users account.Users[regularUserID] = types.NewRegularUser(regularUserID, "", "") + account.Users[ownerUserID] = types.NewOwnerUser(ownerUserID, "", "") account.Users[adminUserID] = types.NewAdminUser(adminUserID) account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"} err = manager.Store.SaveAccount(context.Background(), account) @@ -1705,11 +1706,6 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "not-found"}, expectedErr: status.NewUserNotFoundError("not-found"), }, - { - name: "not part of account", - userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, - expectedErr: status.NewUserNotPartOfAccountError(), - }, { name: "blocked", userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, @@ -1846,7 +1842,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { result, err := am.GetCurrentUserInfo(context.Background(), tc.userAuth) if tc.expectedErr != nil { - assert.Equal(t, err, tc.expectedErr) + assert.Equal(t, tc.expectedErr, err) return } @@ -1911,22 +1907,6 @@ func TestApproveUser(t *testing.T) { _, err = manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) require.Error(t, err) assert.Contains(t, err.Error(), "not pending approval") - - // Test approval by non-admin should fail - regularUser := types.NewRegularUser("regular-user", "", "") - regularUser.AccountID = account.Id - err = manager.Store.SaveUser(context.Background(), regularUser) - require.NoError(t, err) - - pendingUser2 := types.NewRegularUser("pending-user-2", "", "") - pendingUser2.AccountID = account.Id - pendingUser2.Blocked = true - pendingUser2.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser2) - require.NoError(t, err) - - _, err = manager.ApproveUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) - require.Error(t, err) } func TestRejectUser(t *testing.T) { @@ -1971,17 +1951,6 @@ func TestRejectUser(t *testing.T) { err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, regularUser.Id) require.Error(t, err) assert.Contains(t, err.Error(), "not pending approval") - - // Test rejection by non-admin should fail - pendingUser2 := types.NewRegularUser("pending-user-2", "", "") - pendingUser2.AccountID = account.Id - pendingUser2.Blocked = true - pendingUser2.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser2) - require.NoError(t, err) - - err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) - require.Error(t, err) } func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {