diff --git a/client/internal/engine.go b/client/internal/engine.go index 74a07927c..c377c12e1 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1231,36 +1231,19 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer PreSharedKey: e.config.PreSharedKey, } - if e.config.RosenpassEnabled && !e.config.RosenpassPermissive { - lk := []byte(e.config.WgPrivateKey.PublicKey().String()) - rk := []byte(wgConfig.RemoteKey) - var keyInput []byte - if string(lk) > string(rk) { - //nolint:gocritic - keyInput = append(lk[:16], rk[:16]...) - } else { - //nolint:gocritic - keyInput = append(rk[:16], lk[:16]...) - } - - key, err := wgtypes.NewKey(keyInput) - if err != nil { - return nil, err - } - - wgConfig.PreSharedKey = &key - } - // randomize connection timeout timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond config := peer.ConnConfig{ - Key: pubKey, - LocalKey: e.config.WgPrivateKey.PublicKey().String(), - Timeout: timeout, - WgConfig: wgConfig, - LocalWgPort: e.config.WgPort, - RosenpassPubKey: e.getRosenpassPubKey(), - RosenpassAddr: e.getRosenpassAddr(), + Key: pubKey, + LocalKey: e.config.WgPrivateKey.PublicKey().String(), + Timeout: timeout, + WgConfig: wgConfig, + LocalWgPort: e.config.WgPort, + RosenpassConfig: peer.RosenpassConfig{ + PubKey: e.getRosenpassPubKey(), + Addr: e.getRosenpassAddr(), + PermissiveMode: e.config.RosenpassPermissive, + }, ICEConfig: icemaker.Config{ StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 85f94b53f..44e8997bc 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -60,6 +60,15 @@ type WgConfig struct { PreSharedKey *wgtypes.Key } +type RosenpassConfig struct { + // RosenpassPubKey is this peer's Rosenpass public key + PubKey []byte + // RosenpassPubKey is this peer's RosenpassAddr server address (IP:port) + Addr string + + PermissiveMode bool +} + // ConnConfig is a peer Connection configuration type ConnConfig struct { // Key is a public key of a remote peer @@ -73,10 +82,7 @@ type ConnConfig struct { LocalWgPort int - // RosenpassPubKey is this peer's Rosenpass public key - RosenpassPubKey []byte - // RosenpassPubKey is this peer's RosenpassAddr server address (IP:port) - RosenpassAddr string + RosenpassConfig RosenpassConfig // ICEConfig ICE protocol configuration ICEConfig icemaker.Config @@ -109,6 +115,8 @@ type Conn struct { connIDICE nbnet.ConnectionID beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc + // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice + rosenpassRemoteKey []byte wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy @@ -375,7 +383,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC wgProxy.Work() } - if err = conn.configureWGEndpoint(ep); err != nil { + if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } @@ -408,7 +416,7 @@ func (conn *Conn) onICEStateDisconnected() { conn.dumpState.SwitchToRelay() conn.wgProxyRelay.Work() - if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { + if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } @@ -478,7 +486,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { } wgProxy.Work() - if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { + if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.log.Warnf("Failed to close relay connection: %v", err) } @@ -493,6 +501,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { }() wgConfigWorkaround() + conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.currentConnPriority = connPriorityRelay conn.statusRelay.Set(StatusConnected) conn.setRelayedProxy(wgProxy) @@ -556,13 +565,14 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) { } } -func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error { +func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error { + presharedKey := conn.presharedKey(remoteRPKey) return conn.config.WgConfig.WgInterface.UpdatePeer( conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, - conn.config.WgConfig.PreSharedKey, + presharedKey, ) } @@ -783,6 +793,44 @@ func (conn *Conn) AllowedIP() netip.Addr { return conn.config.WgConfig.AllowedIps[0].Addr() } +func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key { + if conn.config.RosenpassConfig.PubKey == nil { + return conn.config.WgConfig.PreSharedKey + } + + if remoteRosenpassKey == nil && conn.config.RosenpassConfig.PermissiveMode { + return conn.config.WgConfig.PreSharedKey + } + + determKey, err := conn.rosenpassDetermKey() + if err != nil { + conn.log.Errorf("failed to generate Rosenpass initial key: %v", err) + return conn.config.WgConfig.PreSharedKey + } + + return determKey +} + +// todo: move this logic into Rosenpass package +func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) { + lk := []byte(conn.config.LocalKey) + rk := []byte(conn.config.Key) // remote key + var keyInput []byte + if string(lk) > string(rk) { + //nolint:gocritic + keyInput = append(lk[:16], rk[:16]...) + } else { + //nolint:gocritic + keyInput = append(rk[:16], lk[:16]...) + } + + key, err := wgtypes.NewKey(keyInput) + if err != nil { + return nil, err + } + return &key, nil +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 505bedb7f..6d55cfff4 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -2,6 +2,7 @@ package peer import ( "context" + "fmt" "os" "sync" "testing" @@ -161,3 +162,145 @@ func TestConn_Status(t *testing.T) { }) } } + +func TestConn_presharedKey(t *testing.T) { + conn1 := Conn{ + config: ConnConfig{ + Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + RosenpassConfig: RosenpassConfig{}, + }, + } + conn2 := Conn{ + config: ConnConfig{ + Key: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + LocalKey: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + RosenpassConfig: RosenpassConfig{}, + }, + } + + tests := []struct { + conn1Permissive bool + conn1RosenpassEnabled bool + conn2Permissive bool + conn2RosenpassEnabled bool + conn1ExpectedInitialKey bool + conn2ExpectedInitialKey bool + }{ + { + conn1Permissive: false, + conn1RosenpassEnabled: false, + conn2Permissive: false, + conn2RosenpassEnabled: false, + conn1ExpectedInitialKey: false, + conn2ExpectedInitialKey: false, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: true, + conn2Permissive: false, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: true, + conn2ExpectedInitialKey: true, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: true, + conn2Permissive: false, + conn2RosenpassEnabled: false, + conn1ExpectedInitialKey: true, + conn2ExpectedInitialKey: false, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: false, + conn2Permissive: false, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: false, + conn2ExpectedInitialKey: true, + }, + { + conn1Permissive: true, + conn1RosenpassEnabled: true, + conn2Permissive: false, + conn2RosenpassEnabled: false, + conn1ExpectedInitialKey: false, + conn2ExpectedInitialKey: false, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: false, + conn2Permissive: true, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: false, + conn2ExpectedInitialKey: false, + }, + { + conn1Permissive: true, + conn1RosenpassEnabled: true, + conn2Permissive: true, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: true, + conn2ExpectedInitialKey: true, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: false, + conn2Permissive: false, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: false, + conn2ExpectedInitialKey: true, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: true, + conn2Permissive: true, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: true, + conn2ExpectedInitialKey: true, + }, + } + + conn1.config.RosenpassConfig.PermissiveMode = true + for i, test := range tests { + tcase := i + 1 + t.Run(fmt.Sprintf("Rosenpass test case %d", tcase), func(t *testing.T) { + conn1.config.RosenpassConfig = RosenpassConfig{} + conn2.config.RosenpassConfig = RosenpassConfig{} + + if test.conn1RosenpassEnabled { + conn1.config.RosenpassConfig.PubKey = []byte("dummykey") + } + conn1.config.RosenpassConfig.PermissiveMode = test.conn1Permissive + + if test.conn2RosenpassEnabled { + conn2.config.RosenpassConfig.PubKey = []byte("dummykey") + } + conn2.config.RosenpassConfig.PermissiveMode = test.conn2Permissive + + conn1PresharedKey := conn1.presharedKey(conn2.config.RosenpassConfig.PubKey) + conn2PresharedKey := conn2.presharedKey(conn1.config.RosenpassConfig.PubKey) + + if test.conn1ExpectedInitialKey { + if conn1PresharedKey == nil { + t.Errorf("Case %d: Expected conn1 to have a non-nil key, but got nil", tcase) + } + } else { + if conn1PresharedKey != nil { + t.Errorf("Case %d: Expected conn1 to have a nil key, but got %v", tcase, conn1PresharedKey) + } + } + + // Assert conn2's key expectation + if test.conn2ExpectedInitialKey { + if conn2PresharedKey == nil { + t.Errorf("Case %d: Expected conn2 to have a non-nil key, but got nil", tcase) + } + } else { + if conn2PresharedKey != nil { + t.Errorf("Case %d: Expected conn2 to have a nil key, but got %v", tcase, conn2PresharedKey) + } + } + }) + } +} diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index d23727e96..224ea0262 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -154,8 +154,8 @@ func (h *Handshaker) sendOffer() error { IceCredentials: IceCredentials{iceUFrag, icePwd}, WgListenPort: h.config.LocalWgPort, Version: version.NetbirdVersion(), - RosenpassPubKey: h.config.RosenpassPubKey, - RosenpassAddr: h.config.RosenpassAddr, + RosenpassPubKey: h.config.RosenpassConfig.PubKey, + RosenpassAddr: h.config.RosenpassConfig.Addr, } addr, err := h.relay.RelayInstanceAddress() @@ -174,8 +174,8 @@ func (h *Handshaker) sendAnswer() error { IceCredentials: IceCredentials{uFrag, pwd}, WgListenPort: h.config.LocalWgPort, Version: version.NetbirdVersion(), - RosenpassPubKey: h.config.RosenpassPubKey, - RosenpassAddr: h.config.RosenpassAddr, + RosenpassPubKey: h.config.RosenpassConfig.PubKey, + RosenpassAddr: h.config.RosenpassConfig.Addr, } addr, err := h.relay.RelayInstanceAddress() if err == nil { diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 2874604fd..72c4758f4 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -10,20 +10,27 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/errors" - route "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/route" ) type RouteSelector struct { mu sync.RWMutex selectedRoutes map[route.NetID]struct{} selectAll bool + + // Indicates if new routes should be automatically selected + includeNewRoutes bool + + // All known routes at the time of deselection + knownRoutes []route.NetID } func NewRouteSelector() *RouteSelector { return &RouteSelector{ - selectedRoutes: map[route.NetID]struct{}{}, - // default selects all routes - selectAll: true, + selectedRoutes: map[route.NetID]struct{}{}, + selectAll: true, + includeNewRoutes: false, + knownRoutes: []route.NetID{}, } } @@ -46,6 +53,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al rs.selectedRoutes[route] = struct{}{} } rs.selectAll = false + rs.includeNewRoutes = false return errors.FormatErrorOrNil(err) } @@ -57,16 +65,22 @@ func (rs *RouteSelector) SelectAllRoutes() { rs.selectAll = true rs.selectedRoutes = map[route.NetID]struct{}{} + rs.includeNewRoutes = false } // DeselectRoutes removes specific routes from the selection. -// If the selector is in "select all" mode, it will transition to "select specific" mode. +// If the selector is in "select all" mode, it will transition to "select specific" mode +// but will keep new routes selected. func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { rs.mu.Lock() defer rs.mu.Unlock() if rs.selectAll { rs.selectAll = false + rs.includeNewRoutes = true + rs.knownRoutes = make([]route.NetID, len(allRoutes)) + copy(rs.knownRoutes, allRoutes) + rs.selectedRoutes = map[route.NetID]struct{}{} for _, route := range allRoutes { rs.selectedRoutes[route] = struct{}{} @@ -92,6 +106,7 @@ func (rs *RouteSelector) DeselectAllRoutes() { defer rs.mu.Unlock() rs.selectAll = false + rs.includeNewRoutes = false rs.selectedRoutes = map[route.NetID]struct{}{} } @@ -103,8 +118,20 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { if rs.selectAll { return true } + + // Check if the route exists in selectedRoutes _, selected := rs.selectedRoutes[routeID] - return selected + if selected { + return true + } + + // If includeNewRoutes is true and this is a new route (not in knownRoutes), + // then it should be selected + if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) { + return true + } + + return false } // FilterSelected removes unselected routes from the provided map. @@ -118,7 +145,11 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { filtered := route.HAMap{} for id, rt := range routes { - if rs.IsSelected(id.NetID()) { + netID := id.NetID() + _, selected := rs.selectedRoutes[netID] + + // Include if directly selected or if it's a new route and includeNewRoutes is true + if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) { filtered[id] = rt } } @@ -131,11 +162,15 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) { defer rs.mu.RUnlock() return json.Marshal(struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` - SelectAll bool `json:"select_all"` + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + IncludeNewRoutes bool `json:"include_new_routes"` + KnownRoutes []route.NetID `json:"known_routes"` }{ - SelectAll: rs.selectAll, - SelectedRoutes: rs.selectedRoutes, + SelectAll: rs.selectAll, + SelectedRoutes: rs.selectedRoutes, + IncludeNewRoutes: rs.includeNewRoutes, + KnownRoutes: rs.knownRoutes, }) } @@ -149,12 +184,16 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { if len(data) == 0 || string(data) == "null" { rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectAll = true + rs.includeNewRoutes = false + rs.knownRoutes = []route.NetID{} return nil } var temp struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` - SelectAll bool `json:"select_all"` + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + IncludeNewRoutes bool `json:"include_new_routes"` + KnownRoutes []route.NetID `json:"known_routes"` } if err := json.Unmarshal(data, &temp); err != nil { @@ -163,10 +202,15 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { rs.selectedRoutes = temp.SelectedRoutes rs.selectAll = temp.SelectAll + rs.includeNewRoutes = temp.IncludeNewRoutes + rs.knownRoutes = temp.KnownRoutes if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } + if rs.knownRoutes == nil { + rs.knownRoutes = []route.NetID{} + } return nil } diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index b1671f254..a1461dff6 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -316,7 +316,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) }, // After deselecting specific routes, new routes should remain unselected - wantNewSelected: []route.NetID{"route2", "route3"}, + wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"}, }, { name: "New routes after selecting with append", @@ -358,3 +358,73 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { }) } } + +func TestRouteSelector_MixedSelectionDeselection(t *testing.T) { + allRoutes := []route.NetID{"route1", "route2", "route3"} + + tests := []struct { + name string + routesToSelect []route.NetID + selectAppend bool + routesToDeselect []route.NetID + selectFirst bool + wantSelectedFinal []route.NetID + }{ + { + name: "1. Select A, then Deselect B", + routesToSelect: []route.NetID{"route1"}, + selectAppend: false, + routesToDeselect: []route.NetID{"route2"}, + selectFirst: true, + wantSelectedFinal: []route.NetID{"route1"}, + }, + { + name: "2. Select A, then Deselect A", + routesToSelect: []route.NetID{"route1"}, + selectAppend: false, + routesToDeselect: []route.NetID{"route1"}, + selectFirst: true, + wantSelectedFinal: []route.NetID{}, + }, + { + name: "3. Deselect A (from all), then Select B", + routesToSelect: []route.NetID{"route2"}, + selectAppend: false, + routesToDeselect: []route.NetID{"route1"}, + selectFirst: false, + wantSelectedFinal: []route.NetID{"route2"}, + }, + { + name: "4. Deselect A (from all), then Select A", + routesToSelect: []route.NetID{"route1"}, + selectAppend: false, + routesToDeselect: []route.NetID{"route1"}, + selectFirst: false, + wantSelectedFinal: []route.NetID{"route1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + var err1, err2 error + + if tt.selectFirst { + err1 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes) + require.NoError(t, err1) + err2 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes) + require.NoError(t, err2) + } else { + err1 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes) + require.NoError(t, err1) + err2 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes) + require.NoError(t, err2) + } + + for _, r := range allRoutes { + assert.Equal(t, slices.Contains(tt.wantSelectedFinal, r), rs.IsSelected(r), "Route %s final state mismatch", r) + } + }) + } +} diff --git a/management/server/account.go b/management/server/account.go index 1627959d2..d7f108dfe 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -283,7 +283,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, err } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) if err != nil { return nil, fmt.Errorf("failed to validate user permissions: %w", err) } @@ -533,7 +533,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return err } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete) if err != nil { return fmt.Errorf("failed to validate user permissions: %w", err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 7a4afb682..2ddc9a3d0 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -60,15 +60,15 @@ type Manager interface { GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) - SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error - SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error + SaveGroup(ctx context.Context, accountID, userID string, group *types.Group, create bool) error + SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group, create bool) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) + SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) @@ -94,7 +94,7 @@ type Manager interface { HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager diff --git a/management/server/account_test.go b/management/server/account_test.go index cf4523e70..7f34cf845 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1115,7 +1115,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { Name: "GroupA", Peers: []string{}, } - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { t.Errorf("save group: %v", err) return } @@ -1131,7 +1131,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { Action: types.PolicyTrafficActionAccept, }, }, - }) + }, true) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1150,7 +1150,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { }() group.Peers = []string{peer1.ID, peer2.ID, peer3.ID} - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { t.Errorf("save group: %v", err) return } @@ -1192,7 +1192,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, } - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { t.Errorf("save group: %v", err) return } @@ -1223,7 +1223,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { Action: types.PolicyTrafficActionAccept, }, }, - }) + }, true) if err != nil { t.Errorf("delete default rule: %v", err) return @@ -1240,7 +1240,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { Name: "GroupA", Peers: []string{peer1.ID, peer3.ID}, } - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { t.Errorf("save group: %v", err) return } @@ -1256,7 +1256,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { Action: types.PolicyTrafficActionAccept, }, }, - }) + }, true) if err != nil { t.Errorf("save policy: %v", err) return @@ -1295,7 +1295,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }) + }, true) require.NoError(t, err, "failed to save group") @@ -1315,7 +1315,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { Action: types.PolicyTrafficActionAccept, }, }, - }) + }, true) if err != nil { t.Errorf("save policy: %v", err) return diff --git a/management/server/dns.go b/management/server/dns.go index d457db773..a3f32c2a9 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -81,7 +81,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 8a0e0cd02..36476b14c 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -504,7 +504,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { Name: "GroupB", Peers: []string{}, }, - }) + }, true) assert.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -564,7 +564,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }) + }, true) assert.NoError(t, err) done := make(chan struct{}) diff --git a/management/server/group.go b/management/server/group.go index c102cedb8..0bd840798 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -66,17 +66,21 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { +func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, create bool) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}) + return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}, create) } // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Write) +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error { + operation := operations.Create + if !create { + operation = operations.Update + } + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operation) if err != nil { return status.NewPermissionValidationError(err) } @@ -203,7 +207,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use // If an error occurs while deleting a group, the function skips it and continues deleting other groups. // Errors are collected and returned at the end. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/group_test.go b/management/server/group_test.go index dffaa80e3..4966f2b33 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -40,7 +40,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { } for _, group := range account.Groups { group.Issued = types.GroupIssuedIntegration - err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) + err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true) if err != nil { t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration) } @@ -48,7 +48,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { for _, group := range account.Groups { group.Issued = types.GroupIssuedJWT - err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) + err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true) if err != nil { t.Errorf("should allow to create %s groups", types.GroupIssuedJWT) } @@ -56,7 +56,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { for _, group := range account.Groups { group.Issued = types.GroupIssuedAPI group.ID = "" - err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) + err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true) if err == nil { t.Errorf("should not create api group with the same name, %s", group.Name) } @@ -162,7 +162,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { } } - err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups) + err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups, true) assert.NoError(t, err, "Failed to save test groups") testCases := []struct { @@ -382,13 +382,13 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t return nil, nil, err } - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute, true) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2, true) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups, true) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies, true) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys, true) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers, true) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration, true) acc, err := am.Store.GetAccount(context.Background(), account.Id) if err != nil { @@ -426,7 +426,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { Name: "GroupE", Peers: []string{peer2.ID}, }, - }) + }, true) assert.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -446,7 +446,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { ID: "groupB", Name: "GroupB", Peers: []string{peer1.ID, peer2.ID}, - }) + }, true) assert.NoError(t, err) select { @@ -524,7 +524,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { Action: types.PolicyTrafficActionAccept, }, }, - }) + }, true) assert.NoError(t, err) // Saving a group linked to policy should update account peers and send peer update @@ -539,7 +539,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, - }) + }, true) assert.NoError(t, err) select { @@ -608,7 +608,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { ID: "groupC", Name: "GroupC", Peers: []string{peer1.ID, peer3.ID}, - }) + }, true) assert.NoError(t, err) select { @@ -649,7 +649,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }) + }, true) assert.NoError(t, err) select { @@ -676,7 +676,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { ID: "groupD", Name: "GroupD", Peers: []string{peer1.ID}, - }) + }, true) assert.NoError(t, err) select { @@ -723,7 +723,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { ID: "groupE", Name: "GroupE", Peers: []string{peer2.ID, peer3.ID}, - }) + }, true) assert.NoError(t, err) select { diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go index 48e28d4f8..df4b6c3d6 100644 --- a/management/server/groups/manager.go +++ b/management/server/groups/manager.go @@ -72,7 +72,7 @@ func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID str } func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Write) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update) if err != nil { return err } diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index 667095018..3ae833dc0 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -143,7 +143,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { IntegrationReference: existingGroup.IntegrationReference, } - if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil { + if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, false); err != nil { log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err) util.WriteError(r.Context(), err, w) return @@ -203,7 +203,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { Issued: types.GroupIssuedAPI, } - err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group) + err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, true) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index f4ac34e53..2caa2f5bf 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -35,7 +35,7 @@ var TestPeers = map[string]*nbpeer.Peer{ func initGroupTestData(initGroups ...*types.Group) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group) error { + SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index 02db2a13a..9ff7ea0ea 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -95,7 +95,7 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { return } - h.savePolicy(w, r, accountID, userID, policyID) + h.savePolicy(w, r, accountID, userID, policyID, false) } // createPolicy handles policy creation request @@ -108,11 +108,11 @@ func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { accountID, userID := userAuth.AccountId, userAuth.UserId - h.savePolicy(w, r, accountID, userID, "") + h.savePolicy(w, r, accountID, userID, "", true) } // savePolicy handles policy creation and update -func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { +func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string, create bool) { var req api.PutApiPoliciesPolicyIdJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -279,7 +279,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s policy.SourcePostureChecks = *req.SourcePostureChecks } - policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy) + policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy, create) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index 6450295eb..6f3dbc792 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -34,7 +34,7 @@ func initPoliciesTestData(policies ...*types.Policy) *handler { } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *types.Policy) (*types.Policy, error) { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *types.Policy, create bool) (*types.Policy, error) { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index 310a1a2f9..2925f96ef 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -84,7 +84,7 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http return } - p.savePostureChecks(w, r, accountID, userID, postureChecksID) + p.savePostureChecks(w, r, accountID, userID, postureChecksID, false) } // createPostureCheck handles posture check creation request @@ -97,7 +97,7 @@ func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http accountID, userID := userAuth.AccountId, userAuth.UserId - p.savePostureChecks(w, r, accountID, userID, "") + p.savePostureChecks(w, r, accountID, userID, "", true) } // getPostureCheck handles a posture check Get request identified by ID @@ -150,7 +150,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http } // savePostureChecks handles posture checks create and update -func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { +func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string, create bool) { var ( err error req api.PostureCheckUpdate @@ -175,7 +175,7 @@ func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http. return } - postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks) + postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks, create) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index e3844caa2..e875b3738 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -40,7 +40,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH } return p, nil }, - SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) { postureChecks.ID = "postureCheck" testPostureChecks[postureChecks.ID] = postureChecks diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 182b07714..8088585d5 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -45,8 +45,8 @@ type MockAccountManager struct { GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) - SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group) error - SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group) error + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error + SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error @@ -54,7 +54,7 @@ type MockAccountManager struct { GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) @@ -98,7 +98,7 @@ type MockAccountManager struct { HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() account.ExternalCacheManager GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager @@ -324,17 +324,17 @@ func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, gro } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error { +func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *types.Group, create bool) error { if am.SaveGroupFunc != nil { - return am.SaveGroupFunc(ctx, accountID, userID, group) + return am.SaveGroupFunc(ctx, accountID, userID, group, create) } return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented") } // SaveGroups mock implementation of SaveGroups from server.AccountManager interface -func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { +func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error { if am.SaveGroupsFunc != nil { - return am.SaveGroupsFunc(ctx, accountID, userID, groups) + return am.SaveGroupsFunc(ctx, accountID, userID, groups, create) } return status.Errorf(codes.Unimplemented, "method SaveGroups is not implemented") } @@ -388,9 +388,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy) + return am.SavePolicyFunc(ctx, accountID, userID, policy, create) } return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } @@ -724,9 +724,9 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p } // SavePostureChecks mocks SavePostureChecks of the AccountManager interface -func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { +func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) { if am.SavePostureChecksFunc != nil { - return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) + return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks, create) } return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 773377f7a..797d7c11c 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -38,7 +38,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -99,7 +99,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Update) if err != nil { return status.NewPermissionValidationError(err) } @@ -149,7 +149,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index dd1149a03..1ba790797 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -965,7 +965,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { Name: "GroupB", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, }, - }) + }, true) assert.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index eba3a1fe1..1c46e9281 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -60,7 +60,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri } func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Write) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -96,7 +96,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network } func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Write) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -118,7 +118,7 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network } func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Write) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 9efd1fae6..21d1e54de 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -95,7 +95,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, } func (m *managerImpl) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Write) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -187,7 +187,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ } func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Write) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -307,7 +307,7 @@ func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction stor } func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Write) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 2c8f7f677..7b488b361 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -80,7 +80,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use } func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Write) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -149,7 +149,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI } func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Write) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -195,7 +195,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t } func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Write) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/peer.go b/management/server/peer.go index 05e3b176b..27825a148 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -190,7 +190,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -321,7 +321,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 0afaed829..406c3e49e 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -303,12 +303,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { group1.Peers = append(group1.Peers, peer1.ID) group2.Peers = append(group2.Peers, peer2.ID) - err = manager.SaveGroup(context.Background(), account.Id, userID, &group1) + err = manager.SaveGroup(context.Background(), account.Id, userID, &group1, true) if err != nil { t.Errorf("expecting group1 to be added, got failure %v", err) return } - err = manager.SaveGroup(context.Background(), account.Id, userID, &group2) + err = manager.SaveGroup(context.Background(), account.Id, userID, &group2, true) if err != nil { t.Errorf("expecting group2 to be added, got failure %v", err) return @@ -327,7 +327,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { }, }, } - policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -375,7 +375,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -1478,7 +1478,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { Name: "GroupC", Peers: []string{}, }, - }) + }, true) require.NoError(t, err) // create a user with auto groups @@ -1654,7 +1654,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { Action: types.PolicyTrafficActionAccept, }, }, - }) + }, true) require.NoError(t, err) done := make(chan struct{}) diff --git a/management/server/permissions/operations/operation.go b/management/server/permissions/operations/operation.go index af709de3a..11481234f 100644 --- a/management/server/permissions/operations/operation.go +++ b/management/server/permissions/operations/operation.go @@ -3,6 +3,8 @@ package operations type Operation string const ( - Read Operation = "read" - Write Operation = "write" + Create Operation = "create" + Read Operation = "read" + Update Operation = "update" + Delete Operation = "delete" ) diff --git a/management/server/permissions/roles/admin.go b/management/server/permissions/roles/admin.go index a826d186a..af3a81297 100644 --- a/management/server/permissions/roles/admin.go +++ b/management/server/permissions/roles/admin.go @@ -9,13 +9,17 @@ import ( var Admin = RolePermissions{ Role: types.UserRoleAdmin, AutoAllowNew: map[operations.Operation]bool{ - operations.Read: true, - operations.Write: true, + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, }, Permissions: Permissions{ modules.Accounts: { - operations.Read: true, - operations.Write: false, + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, }, }, } diff --git a/management/server/permissions/roles/owner.go b/management/server/permissions/roles/owner.go index f739d18ea..668470e47 100644 --- a/management/server/permissions/roles/owner.go +++ b/management/server/permissions/roles/owner.go @@ -8,7 +8,9 @@ import ( var Owner = RolePermissions{ Role: types.UserRoleOwner, AutoAllowNew: map[operations.Operation]bool{ - operations.Read: true, - operations.Write: true, + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, }, } diff --git a/management/server/permissions/roles/user.go b/management/server/permissions/roles/user.go index 6e8a9307b..bb3df0aea 100644 --- a/management/server/permissions/roles/user.go +++ b/management/server/permissions/roles/user.go @@ -8,7 +8,9 @@ import ( var User = RolePermissions{ Role: types.UserRoleUser, AutoAllowNew: map[operations.Operation]bool{ - operations.Read: false, - operations.Write: false, + operations.Read: false, + operations.Create: false, + operations.Update: false, + operations.Delete: false, }, } diff --git a/management/server/policy.go b/management/server/policy.go index 8f56bd493..1e9331d43 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -31,11 +31,15 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Write) + operation := operations.Create + if !create { + operation = operations.Update + } + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operation) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -87,7 +91,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 10b7fc2d1..0c1160cda 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -883,7 +883,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Name: "GroupD", Peers: []string{peer1.ID, peer2.ID}, }, - }) + }, true) assert.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -915,7 +915,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: types.PolicyTrafficActionAccept, }, }, - }) + }, true) assert.NoError(t, err) select { @@ -947,7 +947,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: types.PolicyTrafficActionAccept, }, }, - }) + }, true) assert.NoError(t, err) select { @@ -979,7 +979,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: types.PolicyTrafficActionAccept, }, }, - }) + }, true) assert.NoError(t, err) select { @@ -1010,7 +1010,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: types.PolicyTrafficActionAccept, }, }, - }) + }, true) assert.NoError(t, err) select { @@ -1030,7 +1030,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { }() policyWithSourceAndDestinationPeers.Enabled = false - policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers, true) assert.NoError(t, err) select { @@ -1051,7 +1051,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { policyWithSourceAndDestinationPeers.Description = "updated description" policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"} - policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers, true) assert.NoError(t, err) select { @@ -1071,7 +1071,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { }() policyWithSourceAndDestinationPeers.Enabled = true - policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers, true) assert.NoError(t, err) select { diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 9b8067b8c..f91e89b45 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -31,11 +31,15 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID } // SavePostureChecks saves a posture check. -func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { +func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Write) + operation := operations.Create + if !create { + operation = operations.Update + } + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operation) if err != nil { return nil, status.NewPermissionValidationError(err) } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index bad162f05..232955f7d 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -33,7 +33,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) + _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}, true) assert.Error(t, err) // regular users cannot list check @@ -48,7 +48,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { MinVersion: "0.26.0", }, }, - }) + }, true) assert.NoError(t, err) // admin users can list check @@ -68,7 +68,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { }, }, }, - }) + }, true) assert.Error(t, err) // admins can update posture checks @@ -77,7 +77,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { MinVersion: "0.27.0", }, } - _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck) + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck, true) assert.NoError(t, err) // users should not be able to delete posture checks @@ -137,7 +137,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Name: "GroupC", Peers: []string{}, }, - }) + }, true) assert.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -156,7 +156,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA) + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true) require.NoError(t, err) postureCheckB := &posture.Checks{ @@ -177,7 +177,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true) assert.NoError(t, err) select { @@ -200,7 +200,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { MinVersion: "0.29.0", }, } - _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true) assert.NoError(t, err) select { @@ -232,7 +232,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) assert.NoError(t, err) select { @@ -261,7 +261,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true) assert.NoError(t, err) select { @@ -280,7 +280,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }() policy.SourcePostureChecks = []string{} - _, err := manager.SavePolicy(context.Background(), account.Id, userID, policy) + _, err := manager.SavePolicy(context.Background(), account.Id, userID, policy, true) assert.NoError(t, err) select { @@ -308,7 +308,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true) assert.NoError(t, err) // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update @@ -325,7 +325,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - }) + }, true) assert.NoError(t, err) done := make(chan struct{}) @@ -339,7 +339,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { MinVersion: "0.29.0", }, } - _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true) assert.NoError(t, err) select { @@ -369,7 +369,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - }) + }, true) assert.NoError(t, err) done := make(chan struct{}) @@ -383,7 +383,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { MinVersion: "0.29.0", }, } - _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true) assert.NoError(t, err) select { @@ -408,7 +408,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - }) + }, true) assert.NoError(t, err) done := make(chan struct{}) @@ -426,7 +426,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true) assert.NoError(t, err) select { @@ -465,7 +465,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, }, } - postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA) + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA, true) require.NoError(t, err, "failed to save postureCheckA") postureCheckB := &posture.Checks{ @@ -475,7 +475,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, }, } - postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB, true) require.NoError(t, err, "failed to save postureCheckB") policy := &types.Policy{ @@ -490,7 +490,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { SourcePostureChecks: []string{postureCheckA.ID}, } - policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true) require.NoError(t, err, "failed to save policy") t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { @@ -514,7 +514,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupA"} - _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) @@ -525,7 +525,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupB"} - _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) @@ -546,7 +546,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { policy.Rules[0].Sources = []string{"nonExistentGroup"} policy.Rules[0].Destinations = []string{"nonExistentGroup"} - _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/route.go b/management/server/route.go index 453da92b3..8b91e127a 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -120,7 +120,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -238,7 +238,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update) if err != nil { return status.NewPermissionValidationError(err) } @@ -313,7 +313,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/route_test.go b/management/server/route_test.go index 351dad8f7..dcda3e6d1 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1215,7 +1215,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { Name: "peer1 group", Peers: []string{peer1ID}, } - err = am.SaveGroup(context.Background(), account.Id, userID, newGroup) + err = am.SaveGroup(context.Background(), account.Id, userID, newGroup, true) require.NoError(t, err) rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") @@ -1227,7 +1227,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) + _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, true) require.NoError(t, err) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) @@ -1505,7 +1505,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou } for _, group := range newGroup { - err = am.SaveGroup(context.Background(), accountID, userID, group) + err = am.SaveGroup(context.Background(), accountID, userID, group, true) if err != nil { return nil, err } @@ -1959,7 +1959,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { Name: "GroupC", Peers: []string{}, }, - }) + }, true) assert.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) @@ -2143,7 +2143,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, - }) + }, true) assert.NoError(t, err) select { @@ -2183,7 +2183,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, - }) + }, true) assert.NoError(t, err) select { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index f205a170f..b0903c8d0 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -58,7 +58,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -110,7 +110,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -203,7 +203,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use // DeleteSetupKey removes the setup key from the account func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 6e1e1cf7d..a561de40d 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -41,7 +41,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { Name: "group_name_2", Peers: []string{}, }, - }) + }, true) if err != nil { t.Fatal(err) } @@ -109,7 +109,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { ID: "group_1", Name: "group_name_1", Peers: []string{}, - }) + }, true) if err != nil { t.Fatal(err) } @@ -118,7 +118,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { ID: "group_2", Name: "group_name_2", Peers: []string{}, - }) + }, true) if err != nil { t.Fatal(err) } @@ -403,7 +403,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }) + }, true) assert.NoError(t, err) policy := &types.Policy{ @@ -418,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }, }, } - _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/management/server/user.go b/management/server/user.go index c952100b6..bffdfdb80 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -28,7 +28,7 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -86,7 +86,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, err } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Users, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Users, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -234,7 +234,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return err } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } @@ -291,7 +291,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) if err != nil { return status.NewPermissionValidationError(err) } @@ -338,7 +338,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -380,7 +380,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } @@ -502,7 +502,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, nil //nolint:nilnil } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) // TODO: split by Create and Update if err != nil { return nil, status.NewPermissionValidationError(err) } @@ -973,7 +973,7 @@ func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUs // If an error occurs while deleting the user, the function skips it and continues deleting other users. // Errors are collected and returned at the end. func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Write) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/user_test.go b/management/server/user_test.go index e7020fe00..4684d192a 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1338,7 +1338,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }) + }, true) require.NoError(t, err) policy := &types.Policy{ @@ -1353,7 +1353,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }, }, } - _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/release_files/install.sh b/release_files/install.sh index 459645c58..e5a61dcfe 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -109,6 +109,9 @@ add_apt_repo() { curl -sSL https://pkgs.netbird.io/debian/public.key \ | ${SUDO} gpg --dearmor -o /usr/share/keyrings/netbird-archive-keyring.gpg + # Explicitly set the file permission + ${SUDO} chmod 0644 /usr/share/keyrings/netbird-archive-keyring.gpg + echo 'deb [signed-by=/usr/share/keyrings/netbird-archive-keyring.gpg] https://pkgs.netbird.io/debian stable main' \ | ${SUDO} tee /etc/apt/sources.list.d/netbird.list