From 14b3b77bda7251240689791eda71fa4e9a68dd7c Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 7 Apr 2026 14:13:09 +0200 Subject: [PATCH] [management] validate permissions on groups read with name (#5749) --- .../service/manager/l4_port_test.go | 4 +- .../reverseproxy/service/manager/manager.go | 2 +- .../service/manager/manager_test.go | 4 +- management/server/account/manager.go | 2 +- management/server/account/manager_mock.go | 8 +- management/server/group.go | 5 +- .../http/handlers/groups/groups_handler.go | 4 +- .../handlers/groups/groups_handler_test.go | 2 +- management/server/mock_server/account_mock.go | 6 +- management/server/store/store.go | 2 +- management/server/store/store_mock.go | 94 +++++++++---------- 11 files changed, 68 insertions(+), 65 deletions(-) diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go index 4a7647d90..47dce3a64 100644 --- a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -85,8 +85,8 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, - GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { - return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID) + GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) }, } diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 989187826..ed9d4201b 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -1119,7 +1119,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr } groupIDs := make([]string, 0, len(groupNames)) for _, groupName := range groupNames { - g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID) + g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator) if err != nil { return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err) } diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index f6e532118..69d48f10a 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -698,8 +698,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, - GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { - return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID) + GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) }, } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 45af63ae8..b4516d512 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -75,7 +75,7 @@ type Manager interface { GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index 90700c795..36e5fe39f 100644 --- a/management/server/account/manager_mock.go +++ b/management/server/account/manager_mock.go @@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte } // GetGroupByName mocks base method. -func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { +func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID) + ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID) ret0, _ := ret[0].(*types.Group) ret1, _ := ret[1].(error) return ret0, ret1 } // GetGroupByName indicates an expected call of GetGroupByName. -func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID) } // GetIdentityProvider mocks base method. diff --git a/management/server/group.go b/management/server/group.go index 326b167cf..7b5b9b86c 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -61,7 +61,10 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { + return nil, err + } return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) } diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index 56ccc9d0b..f8d161a87 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { groupName := r.URL.Query().Get("name") if groupName != "" { // Get single group by name - group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID) + group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { return } - allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID) + allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 458a15c11..c7b4cbcdd 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return groups, nil }, - GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) { if groupName == "All" { return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index afd2021ac..ff369355e 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -46,7 +46,7 @@ type MockAccountManager struct { AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error @@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) { +func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { if am.GetGroupByNameFunc != nil { - return am.GetGroupByNameFunc(ctx, accountID, groupName) + return am.GetGroupByNameFunc(ctx, groupName, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented") } diff --git a/management/server/store/store.go b/management/server/store/store.go index e24a1efef..f0c34ffa9 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -121,7 +121,7 @@ type Store interface { GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index a8648aed7..5e609c4ec 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -165,34 +165,6 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration) } -// GetClusterSupportsCustomPorts mocks base method. -func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr) - ret0, _ := ret[0].(*bool) - return ret0 -} - -// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts. -func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr) -} - -// GetClusterRequireSubdomain mocks base method. -func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr) - ret0, _ := ret[0].(*bool) - return ret0 -} - -// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain. -func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr) -} - // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -1389,6 +1361,34 @@ func (mr *MockStoreMockRecorder) GetAnyAccountID(ctx interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnyAccountID", reflect.TypeOf((*MockStore)(nil).GetAnyAccountID), ctx) } +// GetClusterRequireSubdomain mocks base method. +func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain. +func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr) +} + +// GetClusterSupportsCustomPorts mocks base method. +func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts. +func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr) +} + // GetCustomDomain mocks base method. func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) { m.ctrl.T.Helper() @@ -1466,18 +1466,18 @@ func (mr *MockStoreMockRecorder) GetGroupByID(ctx, lockStrength, accountID, grou } // GetGroupByName mocks base method. -func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types2.Group, error) { +func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types2.Group, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, groupName, accountID) + ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, accountID, groupName) ret0, _ := ret[0].(*types2.Group) ret1, _ := ret[1].(error) return ret0, ret1 } // GetGroupByName indicates an expected call of GetGroupByName. -func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, groupName, accountID interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, groupName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, groupName, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName) } // GetGroupsByIDs mocks base method. @@ -1974,6 +1974,21 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRouteByID", reflect.TypeOf((*MockStore)(nil).GetRouteByID), ctx, lockStrength, accountID, routeID) } +// GetRoutingPeerNetworks mocks base method. +func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks. +func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID) +} + // GetServiceByDomain mocks base method. func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { m.ctrl.T.Helper() @@ -2361,21 +2376,6 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID) } -// GetRoutingPeerNetworks mocks base method. -func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks. -func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID) -} - // IsPrimaryAccount mocks base method. func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { m.ctrl.T.Helper()