[management] validate permissions on groups read with name (#5749)

This commit is contained in:
Pascal Fischer
2026-04-07 14:13:09 +02:00
committed by GitHub
parent 6da34e483c
commit 14b3b77bda
11 changed files with 68 additions and 65 deletions

View File

@@ -85,8 +85,8 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
accountMgr := &mock_server.MockAccountManager{ accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID) return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}, },
} }

View File

@@ -1119,7 +1119,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
} }
groupIDs := make([]string, 0, len(groupNames)) groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames { for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID) g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err) return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
} }

View File

@@ -698,8 +698,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
accountMgr := &mock_server.MockAccountManager{ accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID) return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}, },
} }

View File

@@ -75,7 +75,7 @@ type Manager interface {
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error

View File

@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
} }
// GetGroupByName mocks base method. // GetGroupByName mocks base method.
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID) ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
ret0, _ := ret[0].(*types.Group) ret0, _ := ret[0].(*types.Group)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetGroupByName indicates an expected call of GetGroupByName. // GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
} }
// GetIdentityProvider mocks base method. // GetIdentityProvider mocks base method.

View File

@@ -61,7 +61,10 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
} }

View File

@@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
groupName := r.URL.Query().Get("name") groupName := r.URL.Query().Get("name")
if groupName != "" { if groupName != "" {
// Get single group by name // Get single group by name
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID) group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID) allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return groups, nil return groups, nil
}, },
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
if groupName == "All" { if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
} }

View File

@@ -46,7 +46,7 @@ type MockAccountManager struct {
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
@@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer(
} }
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) { func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if am.GetGroupByNameFunc != nil { if am.GetGroupByNameFunc != nil {
return am.GetGroupByNameFunc(ctx, accountID, groupName) return am.GetGroupByNameFunc(ctx, groupName, accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
} }

View File

@@ -121,7 +121,7 @@ type Store interface {
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error
UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error

View File

@@ -165,34 +165,6 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
} }
// GetClusterSupportsCustomPorts mocks base method.
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
}
// GetClusterRequireSubdomain mocks base method.
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
}
// Close mocks base method. // Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error { func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1389,6 +1361,34 @@ func (mr *MockStoreMockRecorder) GetAnyAccountID(ctx interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnyAccountID", reflect.TypeOf((*MockStore)(nil).GetAnyAccountID), ctx) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnyAccountID", reflect.TypeOf((*MockStore)(nil).GetAnyAccountID), ctx)
} }
// GetClusterRequireSubdomain mocks base method.
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
}
// GetClusterSupportsCustomPorts mocks base method.
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
}
// GetCustomDomain mocks base method. // GetCustomDomain mocks base method.
func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) { func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1466,18 +1466,18 @@ func (mr *MockStoreMockRecorder) GetGroupByID(ctx, lockStrength, accountID, grou
} }
// GetGroupByName mocks base method. // GetGroupByName mocks base method.
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types2.Group, error) { func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types2.Group, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, groupName, accountID) ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, accountID, groupName)
ret0, _ := ret[0].(*types2.Group) ret0, _ := ret[0].(*types2.Group)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetGroupByName indicates an expected call of GetGroupByName. // GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, groupName, accountID interface{}) *gomock.Call { func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, groupName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, groupName, accountID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName)
} }
// GetGroupsByIDs mocks base method. // GetGroupsByIDs mocks base method.
@@ -1974,6 +1974,21 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRouteByID", reflect.TypeOf((*MockStore)(nil).GetRouteByID), ctx, lockStrength, accountID, routeID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRouteByID", reflect.TypeOf((*MockStore)(nil).GetRouteByID), ctx, lockStrength, accountID, routeID)
} }
// GetRoutingPeerNetworks mocks base method.
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
}
// GetServiceByDomain mocks base method. // GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -2361,21 +2376,6 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
} }
// GetRoutingPeerNetworks mocks base method.
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
}
// IsPrimaryAccount mocks base method. // IsPrimaryAccount mocks base method.
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()