diff --git a/.github/workflows/android-build-validation.yml b/.github/workflows/android-build-validation.yml new file mode 100644 index 000000000..57cbbacb4 --- /dev/null +++ b/.github/workflows/android-build-validation.yml @@ -0,0 +1,41 @@ +name: Android build validation + +on: + push: + branches: + - main + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} + cancel-in-progress: true + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Install Go + uses: actions/setup-go@v4 + with: + go-version: "1.20.x" + - name: Setup Android SDK + uses: android-actions/setup-android@v2 + - name: NDK Cache + id: ndk-cache + uses: actions/cache@v3 + with: + path: /usr/local/lib/android/sdk/ndk + key: ndk-cache-23.1.7779620 + - name: Setup NDK + run: /usr/local/lib/android/sdk/tools/bin/sdkmanager --install "ndk;23.1.7779620" + - name: install gomobile + run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda + - name: gomobile init + run: gomobile init + - name: build android nebtird lib + run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android + env: + CGO_ENABLED: 0 + ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620 \ No newline at end of file diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index f7d6766fc..2987c04b4 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -80,6 +80,7 @@ jobs: CI_NETBIRD_MGMT_IDP: "none" CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret + CI_NETBIRD_SIGNAL_PORT: 12345 run: | grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID @@ -91,6 +92,7 @@ jobs: grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073" grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$' + grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80' grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$' grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM diff --git a/client/android/login.go b/client/android/login.go index ad334541c..8d2636c9a 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -84,10 +84,14 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) { supportsSSO := true err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { + _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { + _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + s, ok := gstatus.FromError(err) + if !ok { + return err + } + if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented { supportsSSO = false err = nil } diff --git a/client/cmd/login.go b/client/cmd/login.go index a5cc3215c..5433db522 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -3,8 +3,6 @@ package cmd import ( "context" "fmt" - "os" - "runtime" "strings" "time" @@ -195,51 +193,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) } - browserAuthMsg := "Please do the SSO login in your browser. \n" + + cmd.Println("Please do the SSO login in your browser. \n" + "If your browser didn't open automatically, use this URL to log in:\n\n" + - verificationURIComplete + " " + codeMsg - - setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" + - "https://docs.netbird.io/how-to/register-machines-using-setup-keys" - - authenticateUsingBrowser := func() { - cmd.Println(browserAuthMsg) - cmd.Println("") - if err := open.Run(verificationURIComplete); err != nil { - cmd.Println(setupKeyAuthMsg) - } - } - - switch runtime.GOOS { - case "windows", "darwin": - authenticateUsingBrowser() - case "linux": - if isLinuxRunningDesktop() { - authenticateUsingBrowser() - } else { - // If current flow is PKCE, it implies the server is anticipating the redirect to localhost. - // Devices lacking browser support are incompatible with this flow.Therefore, - // these devices will need to resort to setup keys instead. - if isPKCEFlow(verificationURIComplete) { - cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" + - "https://docs.netbird.io/how-to/register-machines-using-setup-keys") - } else { - cmd.Println(browserAuthMsg) - } - } + verificationURIComplete + " " + codeMsg) + cmd.Println("") + if err := open.Run(verificationURIComplete); err != nil { + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } - -// isLinuxRunningDesktop checks if a Linux OS is running desktop environment. -func isLinuxRunningDesktop() bool { - return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" -} - -// isPKCEFlow determines if the PKCE flow is active or not, -// by checking the existence of redirect_uri inside the verification URL. -func isPKCEFlow(verificationURL string) bool { - if verificationURL == "" { - return false - } - return strings.Contains(verificationURL, "redirect_uri") -} diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 678072f0b..6d47021dd 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste return nil, nil } accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { t.Fatal(err) } diff --git a/client/internal/acl/manager_create.go b/client/internal/acl/manager_create.go index c573d2c64..2fdca02ae 100644 --- a/client/internal/acl/manager_create.go +++ b/client/internal/acl/manager_create.go @@ -1,4 +1,4 @@ -//go:build !linux +//go:build !linux || android package acl diff --git a/client/internal/acl/manager_create_linux.go b/client/internal/acl/manager_create_linux.go index 4342463d3..05b042351 100644 --- a/client/internal/acl/manager_create_linux.go +++ b/client/internal/acl/manager_create_linux.go @@ -1,3 +1,5 @@ +//go:build !android + package acl import ( diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 794fe0958..8731e4f0b 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -4,8 +4,8 @@ import ( "context" "fmt" "net/http" + "runtime" - log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" @@ -57,25 +57,43 @@ func (t TokenInfo) GetTokenToUse() string { return t.AccessToken } -// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration. +// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration +// +// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow, +// and if that also fails, the authentication process is deemed unsuccessful +// +// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { - log.Debug("loading pkce authorization flow info") - - pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) - if err == nil { - return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) + if runtime.GOOS == "linux" && !isLinuxRunningDesktop() { + return authenticateWithDeviceCodeFlow(ctx, config) } - log.Debugf("loading pkce authorization flow info failed with error: %v", err) - log.Debugf("falling back to device authorization flow info") + pkceFlow, err := authenticateWithPKCEFlow(ctx, config) + if err != nil { + // fallback to device code flow + return authenticateWithDeviceCodeFlow(ctx, config) + } + return pkceFlow, nil +} +// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow +func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { + pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) + if err != nil { + return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) + } + return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) +} + +// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow +func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { s, ok := gstatus.FromError(err) if ok && s.Code() == codes.NotFound { return nil, fmt.Errorf("no SSO provider returned from management. " + - "If you are using hosting Netbird see documentation at " + - "https://github.com/netbirdio/netbird/tree/main/management for details") + "Please proceed with setting up this device using setup keys " + + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } else if ok && s.Code() == codes.Unimplemented { return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+ "please update your server or use Setup Keys to login", config.ManagementURL) diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index d15d49373..a3d0c1309 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -12,7 +12,6 @@ import ( "net/http" "net/url" "strings" - "sync" "time" log "github.com/sirupsen/logrus" @@ -80,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string { } // RequestAuthInfo requests a authorization code login flow information. -func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) { +func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) { state, err := randomBytesInHex(24) if err != nil { return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err) @@ -114,64 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( tokenChan := make(chan *oauth2.Token, 1) errChan := make(chan error, 1) - go p.startServer(tokenChan, errChan) + parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) + if err != nil { + return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err) + } + + server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} + defer func() { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + log.Errorf("failed to close the server: %v", err) + } + }() + + go p.startServer(server, tokenChan, errChan) select { case <-ctx.Done(): return TokenInfo{}, ctx.Err() case token := <-tokenChan: - return p.handleOAuthToken(token) + return p.parseOAuthToken(token) case err := <-errChan: return TokenInfo{}, err } } -func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) { - var wg sync.WaitGroup - - parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) - if err != nil { - errChan <- fmt.Errorf("failed to parse redirect URL: %v", err) - return - } - - server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} - go func() { - if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- err - } - }() - - wg.Add(1) - http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { - defer wg.Done() - - tokenValidatorFunc := func() (*oauth2.Token, error) { - query := req.URL.Query() - - if authError := query.Get(queryError); authError != "" { - authErrorDesc := query.Get(queryErrorDesc) - return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) - } - - // Prevent timing attacks on state - if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { - return nil, fmt.Errorf("invalid state") - } - - code := query.Get(queryCode) - if code == "" { - return nil, fmt.Errorf("missing code") - } - - return p.oAuthConfig.Exchange( - req.Context(), - code, - oauth2.SetAuthURLParam("code_verifier", p.codeVerifier), - ) - } - - token, err := tokenValidatorFunc() +func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + token, err := p.handleRequest(req) if err != nil { renderPKCEFlowTmpl(w, err) errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err) @@ -182,13 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC tokenChan <- token }) - wg.Wait() - if err := server.Shutdown(context.Background()); err != nil { - log.Errorf("error while shutting down pkce flow server: %v", err) + server.Handler = mux + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- err } } -func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) { +func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) { + query := req.URL.Query() + + if authError := query.Get(queryError); authError != "" { + authErrorDesc := query.Get(queryErrorDesc) + return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) + } + + // Prevent timing attacks on the state + if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { + return nil, fmt.Errorf("invalid state") + } + + code := query.Get(queryCode) + if code == "" { + return nil, fmt.Errorf("missing code") + } + + return p.oAuthConfig.Exchange( + req.Context(), + code, + oauth2.SetAuthURLParam("code_verifier", p.codeVerifier), + ) +} + +func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) { tokenInfo := TokenInfo{ AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, diff --git a/client/internal/auth/util.go b/client/internal/auth/util.go index 33a0e6e35..e61e0f175 100644 --- a/client/internal/auth/util.go +++ b/client/internal/auth/util.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "os" "reflect" "strings" ) @@ -60,3 +61,8 @@ func isValidAccessToken(token string, audience string) error { return fmt.Errorf("invalid JWT token audience field") } + +// isLinuxRunningDesktop checks if a Linux OS is running desktop environment +func isLinuxRunningDesktop() bool { + return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" +} diff --git a/client/internal/checkfw/check.go b/client/internal/checkfw/check.go index 59626cbc3..edfd8a5b3 100644 --- a/client/internal/checkfw/check.go +++ b/client/internal/checkfw/check.go @@ -1,3 +1,3 @@ -//go:build !linux +//go:build !linux || android package checkfw diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9f17ff36b..ea4a23a8d 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1049,7 +1049,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { return nil, "", err } accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { return nil, "", err } diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 3e7eb1df6..b70e4cb6e 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -36,7 +36,7 @@ services: volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird ports: - - 10000:80 + - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] diff --git a/infrastructure_files/tests/setup.env b/infrastructure_files/tests/setup.env index 6cf1acdf4..b0999eb51 100644 --- a/infrastructure_files/tests/setup.env +++ b/infrastructure_files/tests/setup.env @@ -21,4 +21,5 @@ NETBIRD_AUTH_USER_ID_CLAIM="email" NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email" NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID -NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET \ No newline at end of file +NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET +NETBIRD_SIGNAL_PORT=12345 \ No newline at end of file diff --git a/management/client/client_test.go b/management/client/client_test.go index deef57329..86c598adb 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -61,7 +61,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 5c3816715..ca333b931 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -31,6 +31,7 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity/sqlite" httpapi "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/idp" @@ -142,12 +143,22 @@ var ( if disableSingleAccMode { mgmtSingleAccModeDomain = "" } - eventStore, err := sqlite.NewSQLiteStore(config.Datadir) + eventStore, key, err := initEventStore(config.Datadir, config.DataStoreEncryptionKey) if err != nil { - return err + return fmt.Errorf("failed to initialize database: %s", err) } + + if key != "" { + log.Debugf("update config with activity store key") + config.DataStoreEncryptionKey = key + err := updateMgmtConfig(mgmtConfig, config) + if err != nil { + return fmt.Errorf("failed to write out store encryption key: %s", err) + } + } + accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, - dnsDomain, eventStore) + dnsDomain, eventStore, userDeleteFromIDPEnabled) if err != nil { return fmt.Errorf("failed to build default manager: %v", err) } @@ -287,6 +298,20 @@ var ( } ) +func initEventStore(dataDir string, key string) (activity.Store, string, error) { + var err error + if key == "" { + log.Debugf("generate new activity store encryption key") + key, err = sqlite.GenerateKey() + if err != nil { + return nil, "", err + } + } + store, err := sqlite.NewSQLiteStore(dataDir, key) + return store, key, err + +} + func notifyStop(msg string) { select { case stopCh <- 1: @@ -440,6 +465,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { return loadedConfig, err } +func updateMgmtConfig(path string, config *server.Config) error { + return util.WriteJson(path, config) +} + // OIDCConfigResponse used for parsing OIDC config response type OIDCConfigResponse struct { Issuer string `json:"issuer"` diff --git a/management/cmd/root.go b/management/cmd/root.go index a149841c5..2080a6b29 100644 --- a/management/cmd/root.go +++ b/management/cmd/root.go @@ -24,6 +24,7 @@ var ( disableMetrics bool disableSingleAccMode bool idpSignKeyRefreshEnabled bool + userDeleteFromIDPEnabled bool rootCmd = &cobra.Command{ Use: "netbird-mgmt", @@ -56,6 +57,7 @@ func init() { mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.") + mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account") rootCmd.MarkFlagRequired("config") //nolint rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") diff --git a/management/server/account.go b/management/server/account.go index 74b2bdfb3..72ba73d7d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -81,7 +81,6 @@ type AccountManager interface { GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) SaveGroup(accountID, userID string, group *Group) error - UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error) DeleteGroup(accountId, userId, groupID string) error ListGroups(accountId string) ([]*Group, error) GroupAddPeer(accountId, groupID, peerID string) error @@ -94,13 +93,11 @@ type AccountManager interface { GetRoute(accountID, routeID, userID string) (*route.Route, error) CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) SaveRoute(accountID, userID string, route *route.Route) error - UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) DeleteRoute(accountID, routeID, userID string) error ListRoutes(accountID, userID string) ([]*route.Route, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) DeleteNameServerGroup(accountID, nsGroupID, userID string) error ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) GetDNSDomain() string @@ -134,6 +131,9 @@ type DefaultAccountManager struct { // dnsDomain is used for peer resolution. This is appended to the peer's name dnsDomain string peerLoginExpiry Scheduler + + // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account + userDeleteFromIDPEnabled bool } // Settings represents Account settings structure that can be modified via API and Dashboard @@ -739,18 +739,19 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, - singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, + singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, userDeleteFromIDPEnabled bool, ) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ - Store: store, - peersUpdateManager: peersUpdateManager, - idpManager: idpManager, - ctx: context.Background(), - cacheMux: sync.Mutex{}, - cacheLoading: map[string]chan struct{}{}, - dnsDomain: dnsDomain, - eventStore: eventStore, - peerLoginExpiry: NewDefaultScheduler(), + Store: store, + peersUpdateManager: peersUpdateManager, + idpManager: idpManager, + ctx: context.Background(), + cacheMux: sync.Mutex{}, + cacheLoading: map[string]chan struct{}{}, + dnsDomain: dnsDomain, + eventStore: eventStore, + peerLoginExpiry: NewDefaultScheduler(), + userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, } allAccounts := store.GetAllAccounts() // enable single account mode only if configured by user and number of existing accounts is not grater than 1 @@ -875,33 +876,19 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() return account.GetNextPeerExpiration() } + expiredPeers := account.GetExpiredPeers() var peerIDs []string - for _, peer := range account.GetExpiredPeers() { - if peer.Status.LoginExpired { - continue - } + for _, peer := range expiredPeers { peerIDs = append(peerIDs, peer.ID) - peer.MarkLoginExpired(true) - account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) - if err != nil { - log.Errorf("failed saving peer status while expiring peer %s", peer.ID) - return account.GetNextPeerExpiration() - } - am.storeEvent(peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain())) } log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) - if len(peerIDs) != 0 { - // this will trigger peer disconnect from the management service - am.peersUpdateManager.CloseChannels(peerIDs) - err = am.updateAccountPeers(account) - if err != nil { - log.Errorf("failed updating account peers while expiring peers for account %s", accountID) - return account.GetNextPeerExpiration() - } + if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) + return account.GetNextPeerExpiration() } + return account.GetNextPeerExpiration() } } @@ -1672,19 +1659,3 @@ func newAccountWithId(accountID, userID, domain string) *Account { } return acc } - -func removeFromList(inputList []string, toRemove []string) []string { - toRemoveMap := make(map[string]struct{}) - for _, item := range toRemove { - toRemoveMap[item] = struct{}{} - } - - var resultList []string - for _, item := range inputList { - _, ok := toRemoveMap[item] - if !ok { - resultList = append(resultList, item) - } - } - return resultList -} diff --git a/management/server/account_test.go b/management/server/account_test.go index 64fd90524..204e98947 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2063,7 +2063,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore, false) } func createStore(t *testing.T) (Store, error) { diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 4de667ded..ce36f520f 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -104,6 +104,8 @@ const ( UserBlocked // UserUnblocked indicates that a user unblocked another user UserUnblocked + // UserDeleted indicates that a user deleted another user + UserDeleted // GroupDeleted indicates that a user deleted group GroupDeleted // UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login @@ -162,6 +164,7 @@ var activityMap = map[Activity]Code{ ServiceUserDeleted: {"Service user deleted", "service.user.delete"}, UserBlocked: {"User blocked", "user.block"}, UserUnblocked: {"User unblocked", "user.unblock"}, + UserDeleted: {"User deleted", "user.delete"}, GroupDeleted: {"Group deleted", "group.delete"}, UserLoggedInPeer: {"User logged in peer", "user.peer.login"}, PeerLoginExpired: {"Peer login expired", "peer.login.expire"}, diff --git a/management/server/activity/event.go b/management/server/activity/event.go index 17ec4a0b0..1bf86ef2c 100644 --- a/management/server/activity/event.go +++ b/management/server/activity/event.go @@ -18,10 +18,13 @@ type Event struct { ID uint64 // InitiatorID is the ID of an object that initiated the event (e.g., a user) InitiatorID string + // InitiatorEmail is the email address of an object that initiated the event. This will be set on deleted users only + InitiatorEmail string // TargetID is the ID of an object that was effected by the event (e.g., a peer) TargetID string // AccountID is the ID of an account where the event happened AccountID string + // Meta of the event, e.g. deleted peer information like name, IP, etc Meta map[string]any } @@ -35,12 +38,13 @@ func (e *Event) Copy() *Event { } return &Event{ - Timestamp: e.Timestamp, - Activity: e.Activity, - ID: e.ID, - InitiatorID: e.InitiatorID, - TargetID: e.TargetID, - AccountID: e.AccountID, - Meta: meta, + Timestamp: e.Timestamp, + Activity: e.Activity, + ID: e.ID, + InitiatorID: e.InitiatorID, + InitiatorEmail: e.InitiatorEmail, + TargetID: e.TargetID, + AccountID: e.AccountID, + Meta: meta, } } diff --git a/management/server/activity/sqlite/crypt.go b/management/server/activity/sqlite/crypt.go new file mode 100644 index 000000000..8f2755604 --- /dev/null +++ b/management/server/activity/sqlite/crypt.go @@ -0,0 +1,81 @@ +package sqlite + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" +) + +var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} + +type EmailEncrypt struct { + block cipher.Block +} + +func GenerateKey() (string, error) { + key := make([]byte, 32) + _, err := rand.Read(key) + if err != nil { + return "", err + } + readableKey := base64.StdEncoding.EncodeToString(key) + return readableKey, nil +} + +func NewEmailEncrypt(key string) (*EmailEncrypt, error) { + binKey, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(binKey) + if err != nil { + return nil, err + } + ec := &EmailEncrypt{ + block: block, + } + + return ec, nil +} + +func (ec *EmailEncrypt) Encrypt(payload string) string { + plainText := pkcs5Padding([]byte(payload)) + cipherText := make([]byte, len(plainText)) + cbc := cipher.NewCBCEncrypter(ec.block, iv) + cbc.CryptBlocks(cipherText, plainText) + return base64.StdEncoding.EncodeToString(cipherText) +} + +func (ec *EmailEncrypt) Decrypt(data string) (string, error) { + cipherText, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return "", err + } + cbc := cipher.NewCBCDecrypter(ec.block, iv) + cbc.CryptBlocks(cipherText, cipherText) + payload, err := pkcs5UnPadding(cipherText) + if err != nil { + return "", err + } + + return string(payload), nil +} + +func pkcs5Padding(ciphertext []byte) []byte { + padding := aes.BlockSize - len(ciphertext)%aes.BlockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(ciphertext, padText...) +} + +func pkcs5UnPadding(src []byte) ([]byte, error) { + srcLen := len(src) + paddingLen := int(src[srcLen-1]) + if paddingLen >= srcLen || paddingLen > aes.BlockSize { + return nil, fmt.Errorf("padding size error") + } + return src[:srcLen-paddingLen], nil +} diff --git a/management/server/activity/sqlite/crypt_test.go b/management/server/activity/sqlite/crypt_test.go new file mode 100644 index 000000000..5fb59a692 --- /dev/null +++ b/management/server/activity/sqlite/crypt_test.go @@ -0,0 +1,63 @@ +package sqlite + +import ( + "testing" +) + +func TestGenerateKey(t *testing.T) { + testData := "exampl@netbird.io" + key, err := GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %s", err) + } + ee, err := NewEmailEncrypt(key) + if err != nil { + t.Fatalf("failed to init email encryption: %s", err) + } + + encrypted := ee.Encrypt(testData) + if encrypted == "" { + t.Fatalf("invalid encrypted text") + } + + decrypted, err := ee.Decrypt(encrypted) + if err != nil { + t.Fatalf("failed to decrypt data: %s", err) + } + + if decrypted != testData { + t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted) + } +} + +func TestCorruptKey(t *testing.T) { + testData := "exampl@netbird.io" + key, err := GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %s", err) + } + ee, err := NewEmailEncrypt(key) + if err != nil { + t.Fatalf("failed to init email encryption: %s", err) + } + + encrypted := ee.Encrypt(testData) + if encrypted == "" { + t.Fatalf("invalid encrypted text") + } + + newKey, err := GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %s", err) + } + + ee, err = NewEmailEncrypt(newKey) + if err != nil { + t.Fatalf("failed to init email encryption: %s", err) + } + + res, err := ee.Decrypt(encrypted) + if err == nil || res == testData { + t.Fatalf("incorrect decryption, the result is: %s", res) + } +} diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go index a4c85cf60..7ff59674d 100644 --- a/management/server/activity/sqlite/sqlite.go +++ b/management/server/activity/sqlite/sqlite.go @@ -3,14 +3,14 @@ package sqlite import ( "database/sql" "encoding/json" - - "github.com/netbirdio/netbird/management/server/activity" - - // sqlite driver + "fmt" "path/filepath" "time" - _ "github.com/mattn/go-sqlite3" + _ "github.com/mattn/go-sqlite3" // sqlite driver + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/activity" ) const ( @@ -25,35 +25,62 @@ const ( "meta TEXT," + " target_id TEXT);" - selectDescQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + - " FROM events WHERE account_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?;" - selectAscQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + - " FROM events WHERE account_id = ? ORDER BY timestamp ASC LIMIT ? OFFSET ?;" + creatTableAccountEmailQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL);` + + selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta + FROM events + LEFT JOIN deleted_users i ON events.initiator_id = i.id + LEFT JOIN deleted_users t ON events.target_id = t.id + WHERE account_id = ? + ORDER BY timestamp DESC LIMIT ? OFFSET ?;` + + selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta + FROM events + LEFT JOIN deleted_users i ON events.initiator_id = i.id + LEFT JOIN deleted_users t ON events.target_id = t.id + WHERE account_id = ? + ORDER BY timestamp ASC LIMIT ? OFFSET ?;` + insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " + "VALUES(?, ?, ?, ?, ?, ?)" + + insertDeleteUserQuery = `INSERT INTO deleted_users(id, email) VALUES(?, ?)` ) // Store is the implementation of the activity.Store interface backed by SQLite type Store struct { - db *sql.DB + db *sql.DB + emailEncrypt *EmailEncrypt + insertStatement *sql.Stmt selectAscStatement *sql.Stmt selectDescStatement *sql.Stmt + deleteUserStmt *sql.Stmt } // NewSQLiteStore creates a new Store with an event table if not exists. -func NewSQLiteStore(dataDir string) (*Store, error) { +func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { dbFile := filepath.Join(dataDir, eventSinkDB) db, err := sql.Open("sqlite3", dbFile) if err != nil { return nil, err } + crypt, err := NewEmailEncrypt(encryptionKey) + if err != nil { + return nil, err + } + _, err = db.Exec(createTableQuery) if err != nil { return nil, err } + _, err = db.Exec(creatTableAccountEmailQuery) + if err != nil { + return nil, err + } + insertStmt, err := db.Prepare(insertQuery) if err != nil { return nil, err @@ -69,25 +96,35 @@ func NewSQLiteStore(dataDir string) (*Store, error) { return nil, err } - return &Store{ + deleteUserStmt, err := db.Prepare(insertDeleteUserQuery) + if err != nil { + return nil, err + } + + s := &Store{ db: db, + emailEncrypt: crypt, insertStatement: insertStmt, selectDescStatement: selectDescStmt, selectAscStatement: selectAscStmt, - }, nil + deleteUserStmt: deleteUserStmt, + } + return s, nil } -func processResult(result *sql.Rows) ([]*activity.Event, error) { +func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { events := make([]*activity.Event, 0) for result.Next() { var id int64 var operation activity.Activity var timestamp time.Time var initiator string + var initiatorEmail *string var target string + var targetEmail *string var account string var jsonMeta string - err := result.Scan(&id, &operation, ×tamp, &initiator, &target, &account, &jsonMeta) + err := result.Scan(&id, &operation, ×tamp, &initiator, &initiatorEmail, &target, &targetEmail, &account, &jsonMeta) if err != nil { return nil, err } @@ -100,7 +137,17 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) { } } - events = append(events, &activity.Event{ + if targetEmail != nil { + email, err := store.emailEncrypt.Decrypt(*targetEmail) + if err != nil { + log.Errorf("failed to decrypt email address for target id: %s", target) + meta["email"] = "" + } else { + meta["email"] = email + } + } + + event := &activity.Event{ Timestamp: timestamp, Activity: operation, ID: uint64(id), @@ -108,7 +155,18 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) { TargetID: target, AccountID: account, Meta: meta, - }) + } + + if initiatorEmail != nil { + email, err := store.emailEncrypt.Decrypt(*initiatorEmail) + if err != nil { + log.Errorf("failed to decrypt email address of initiator: %s", initiator) + } else { + event.InitiatorEmail = email + } + } + + events = append(events, event) } return events, nil @@ -127,13 +185,18 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([ } defer result.Close() //nolint - return processResult(result) + return store.processResult(result) } -// Save an event in the SQLite events table +// Save an event in the SQLite events table end encrypt the "email" element in meta map func (store *Store) Save(event *activity.Event) (*activity.Event, error) { var jsonMeta string - if event.Meta != nil { + meta, err := store.saveDeletedUserEmailInEncrypted(event) + if err != nil { + return nil, err + } + + if meta != nil { metaBytes, err := json.Marshal(event.Meta) if err != nil { return nil, err @@ -156,6 +219,29 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) { return eventCopy, nil } +// saveDeletedUserEmailInEncrypted if the meta contains email then store it in encrypted way and delete this item from +// meta map +func (store *Store) saveDeletedUserEmailInEncrypted(event *activity.Event) (map[string]any, error) { + email, ok := event.Meta["email"] + if !ok { + return event.Meta, nil + } + + delete(event.Meta, "email") + + encrypted := store.emailEncrypt.Encrypt(fmt.Sprintf("%s", email)) + _, err := store.deleteUserStmt.Exec(event.TargetID, encrypted) + if err != nil { + return nil, err + } + + if len(event.Meta) == 1 { + return nil, nil // nolint + } + delete(event.Meta, "email") + return event.Meta, nil +} + // Close the Store func (store *Store) Close() error { if store.db != nil { diff --git a/management/server/activity/sqlite/sqlite_test.go b/management/server/activity/sqlite/sqlite_test.go index 2ca9a1e64..f6a6f9467 100644 --- a/management/server/activity/sqlite/sqlite_test.go +++ b/management/server/activity/sqlite/sqlite_test.go @@ -12,7 +12,8 @@ import ( func TestNewSQLiteStore(t *testing.T) { dataDir := t.TempDir() - store, err := NewSQLiteStore(dataDir) + key, _ := GenerateKey() + store, err := NewSQLiteStore(dataDir, key) if err != nil { t.Fatal(err) return diff --git a/management/server/config.go b/management/server/config.go index ea0143988..31c1cf45c 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -35,7 +35,8 @@ type Config struct { TURNConfig *TURNConfig Signal *Host - Datadir string + Datadir string + DataStoreEncryptionKey string HttpConfig *HttpServerConfig diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 092c52afa..b089949b2 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -191,7 +191,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore, false) } func createDNSStore(t *testing.T) (Store, error) { diff --git a/management/server/group.go b/management/server/group.go index 5b1d2ac9f..697fe5d70 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -33,26 +33,6 @@ type Group struct { Peers []string } -const ( - // UpdateGroupName indicates a name update operation - UpdateGroupName GroupUpdateOperationType = iota - // InsertPeersToGroup indicates insert peers to group operation - InsertPeersToGroup - // RemovePeersFromGroup indicates a remove peers from group operation - RemovePeersFromGroup - // UpdateGroupPeers indicates a replacement of group peers list - UpdateGroupPeers -) - -// GroupUpdateOperationType operation type -type GroupUpdateOperationType int - -// GroupUpdateOperation operation object with type and values to be applied -type GroupUpdateOperation struct { - Type GroupUpdateOperationType - Values []string -} - // EventMeta returns activity event meta related to the group func (g *Group) EventMeta() map[string]any { return map[string]any{"name": g.Name} @@ -165,57 +145,6 @@ func difference(a, b []string) []string { return diff } -// UpdateGroup updates a group using a list of operations -func (am *DefaultAccountManager) UpdateGroup(accountID string, - groupID string, operations []GroupUpdateOperation, -) (*Group, error) { - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, err - } - - groupToUpdate, ok := account.Groups[groupID] - if !ok { - return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID) - } - - group := groupToUpdate.Copy() - - for _, operation := range operations { - switch operation.Type { - case UpdateGroupName: - group.Name = operation.Values[0] - case UpdateGroupPeers: - group.Peers = operation.Values - case InsertPeersToGroup: - sourceList := group.Peers - resultList := removeFromList(sourceList, operation.Values) - group.Peers = append(resultList, operation.Values...) - case RemovePeersFromGroup: - sourceList := group.Peers - resultList := removeFromList(sourceList, operation.Values) - group.Peers = resultList - } - } - - account.Groups[groupID] = group - - account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { - return nil, err - } - - err = am.updateAccountPeers(account) - if err != nil { - return nil, err - } - - return group, nil -} - // DeleteGroup object of the peers func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error { unlock := am.Store.AcquireAccountLock(accountId) diff --git a/management/server/http/api/generate.sh b/management/server/http/api/generate.sh old mode 100644 new mode 100755 diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 06da0ede3..f2d1e26bf 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -922,6 +922,10 @@ components: description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event. type: string example: google-oauth2|123456789012345678901 + initiator_email: + description: The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event. + type: string + example: demo@netbird.io target_id: description: The ID of the target of the event. E.g., an ID of the peer that a user removed. type: string @@ -938,6 +942,7 @@ components: - activity - activity_code - initiator_id + - initiator_email - target_id - meta responses: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 402aae635..33c935a68 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -164,6 +164,9 @@ type Event struct { // Id Event unique identifier Id string `json:"id"` + // InitiatorEmail The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event. + InitiatorEmail string `json:"initiator_email"` + // InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event. InitiatorId string `json:"initiator_id"` diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index 1d1c176e5..cbca44364 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -45,14 +45,46 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { util.WriteError(err, w) return } - events := make([]*api.Event, 0) - for _, e := range accountEvents { - events = append(events, toEventResponse(e)) + events := make([]*api.Event, len(accountEvents)) + for i, e := range accountEvents { + events[i] = toEventResponse(e) + } + + err = h.fillEventsWithInitiatorEmail(events, account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return } util.WriteJSONObject(w, events) } +func (h *EventsHandler) fillEventsWithInitiatorEmail(events []*api.Event, accountId, userId string) error { + // build email map based on users + userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId) + if err != nil { + log.Errorf("failed to get users from account: %s", err) + return err + } + + emails := make(map[string]string) + for _, ui := range userInfos { + emails[ui.ID] = ui.Email + } + + // fill event with email of initiator + var ok bool + for _, event := range events { + if event.InitiatorEmail == "" { + event.InitiatorEmail, ok = emails[event.InitiatorId] + if !ok { + log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId) + } + } + } + return nil +} + func toEventResponse(event *activity.Event) *api.Event { meta := make(map[string]string) if event.Meta != nil { @@ -60,13 +92,15 @@ func toEventResponse(event *activity.Event) *api.Event { meta[s] = fmt.Sprintf("%v", a) } } - return &api.Event{ - Id: fmt.Sprint(event.ID), - InitiatorId: event.InitiatorID, - Activity: event.Activity.Message(), - ActivityCode: api.EventActivityCode(event.Activity.StringCode()), - TargetId: event.TargetID, - Timestamp: event.Timestamp, - Meta: meta, + e := &api.Event{ + Id: fmt.Sprint(event.ID), + InitiatorId: event.InitiatorID, + InitiatorEmail: event.InitiatorEmail, + Activity: event.Activity.Message(), + ActivityCode: api.EventActivityCode(event.Activity.StringCode()), + TargetId: event.TargetID, + Timestamp: event.Timestamp, + Meta: meta, } + return e } diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index a77e44f45..4cfad922b 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -37,6 +37,9 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E }, }, user, nil }, + GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { + return make([]*server.UserInfo, 0), nil + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 44603059a..ddb1233bf 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -53,22 +53,6 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle Issued: server.GroupIssuedAPI, }, nil }, - UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) { - var group server.Group - group.ID = groupID - for _, operation := range operations { - switch operation.Type { - case server.UpdateGroupName: - group.Name = operation.Values[0] - case server.UpdateGroupPeers, server.InsertPeersToGroup: - group.Peers = operation.Values - case server.RemovePeersFromGroup: - default: - return nil, fmt.Errorf("no operation") - } - } - return &group, nil - }, GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) { for _, peer := range TestPeers { if peer.IP.String() == peerIP { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 75fcb4c1c..100f4b87a 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -88,31 +88,6 @@ func initNameserversTestData() *NameserversHandler { } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - UpdateNameServerGroupFunc: func(accountID, nsGroupID, _ string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { - nsGroupToUpdate := baseExistingNSGroup.Copy() - if nsGroupID != nsGroupToUpdate.ID { - return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID) - } - for _, operation := range operations { - switch operation.Type { - case server.UpdateNameServerGroupName: - nsGroupToUpdate.Name = operation.Values[0] - case server.UpdateNameServerGroupDescription: - nsGroupToUpdate.Description = operation.Values[0] - case server.UpdateNameServerGroupNameServers: - var parsedNSList []nbdns.NameServer - for _, nsURL := range operation.Values { - parsed, err := nbdns.ParseNameServerURL(nsURL) - if err != nil { - return nil, err - } - parsedNSList = append(parsedNSList, parsed) - } - nsGroupToUpdate.NameServers = parsedNSList - } - } - return nsGroupToUpdate, nil - }, GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingNSAccount, testingAccount.Users["test_user"], nil }, diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index c4270284c..3f2b7b910 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httptest" "net/netip" - "strconv" "testing" "github.com/netbirdio/netbird/management/server/http/api" @@ -108,38 +107,6 @@ func initRoutesTestData() *RoutesHandler { IP: netip.MustParseAddr(existingPeerID).AsSlice(), }, nil }, - UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) { - routeToUpdate := baseExistingRoute - if routeID != routeToUpdate.ID { - return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID) - } - for _, operation := range operations { - switch operation.Type { - case server.UpdateRouteNetwork: - routeToUpdate.NetworkType, routeToUpdate.Network, _ = route.ParseNetwork(operation.Values[0]) - case server.UpdateRouteDescription: - routeToUpdate.Description = operation.Values[0] - case server.UpdateRouteNetworkIdentifier: - routeToUpdate.NetID = operation.Values[0] - case server.UpdateRoutePeer: - routeToUpdate.Peer = operation.Values[0] - if routeToUpdate.Peer == notFoundPeerID { - return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", routeToUpdate.Peer) - } - case server.UpdateRouteMetric: - routeToUpdate.Metric, _ = strconv.Atoi(operation.Values[0]) - case server.UpdateRouteMasquerade: - routeToUpdate.Masquerade, _ = strconv.ParseBool(operation.Values[0]) - case server.UpdateRouteEnabled: - routeToUpdate.Enabled, _ = strconv.ParseBool(operation.Values[0]) - case server.UpdateRouteGroups: - routeToUpdate.Groups = operation.Values - default: - return nil, fmt.Errorf("no operation") - } - } - return routeToUpdate, nil - }, GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingAccount, testingAccount.Users["test_user"], nil }, diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index 64ec88e9f..d3802d8ad 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -513,7 +513,9 @@ func buildUserExportRequest() (string, error) { return string(str), nil } -func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { +func (am *Auth0Manager) createRequest( + method string, endpoint string, body io.Reader, +) (*http.Request, error) { jwtToken, err := am.credentials.Authenticate() if err != nil { return nil, err @@ -521,17 +523,23 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (* reqURL := am.authIssuer + endpoint - payload := strings.NewReader(payloadStr) - - req, err := http.NewRequest("POST", reqURL, payload) + req, err := http.NewRequest(method, reqURL, body) if err != nil { return nil, err } req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + + return req, nil +} + +func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { + req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr)) + if err != nil { + return nil, err + } req.Header.Add("content-type", "application/json") return req, nil - } // GetAllAccounts gets all registered accounts with corresponding user data. @@ -737,6 +745,38 @@ func (am *Auth0Manager) InviteUserByID(userID string) error { return nil } +// DeleteUser from Auth0 +func (am *Auth0Manager) DeleteUser(userID string) error { + req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil) + if err != nil { + return err + } + + resp, err := am.httpClient.Do(req) + if err != nil { + log.Debugf("execute delete request: %v", err) + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestError() + } + return err + } + + defer func() { + err = resp.Body.Close() + if err != nil { + log.Errorf("close delete request body: %v", err) + } + }() + if resp.StatusCode != 204 { + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestStatusError() + } + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) + } + + return nil +} + // checkExportJobStatus checks the status of the job created at CreateExportUsersJob. // If the status is "completed", then return the downloadLink func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) { diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index c57efef1a..cf6706854 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -458,6 +458,38 @@ func (am *AuthentikManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from Authentik +func (am *AuthentikManager) DeleteUser(userID string) error { + ctx, err := am.authenticationContext() + if err != nil { + return err + } + + userPk, err := strconv.ParseInt(userID, 10, 32) + if err != nil { + return err + } + + resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute() + if err != nil { + return err + } + defer resp.Body.Close() // nolint + + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountDeleteUser() + } + + if resp.StatusCode != http.StatusNoContent { + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestStatusError() + } + return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode) + } + + return nil +} + func (am *AuthentikManager) authenticationContext() (context.Context, error) { jwtToken, err := am.credentials.Authenticate() if err != nil { diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index 7cff7d8fc..22e6825ae 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -454,6 +454,43 @@ func (am *AzureManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from Azure +func (am *AzureManager) DeleteUser(userID string) error { + jwtToken, err := am.credentials.Authenticate() + if err != nil { + return err + } + + reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID)) + req, err := http.NewRequest(http.MethodDelete, reqURL, nil) + if err != nil { + return err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + log.Debugf("delete idp user %s", userID) + + resp, err := am.httpClient.Do(req) + if err != nil { + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestError() + } + return err + } + defer resp.Body.Close() + + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountDeleteUser() + } + + if resp.StatusCode != http.StatusNoContent { + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) + } + + return nil +} + func (am *AzureManager) getUserExtensions() ([]azureExtension, error) { q := url.Values{} q.Add("$select", extensionFields) diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 2e65497dc..40854e598 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -254,6 +254,19 @@ func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from GoogleWorkspace. +func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error { + if err := gm.usersService.Delete(userID).Do(); err != nil { + return err + } + + if gm.appMetrics != nil { + gm.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil +} + // getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey. // It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. // If that fails, it falls back to using the default Google credentials path. diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index f3758743d..7e1064da1 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -26,6 +26,7 @@ type Manager interface { CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) GetUserByEmail(email string) ([]*UserData, error) InviteUserByID(userID string) error + DeleteUser(userID string) error } // ClientConfig defines common client configuration for all IdP manager diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go index 12ed87389..d65a78ae3 100644 --- a/management/server/idp/keycloak.go +++ b/management/server/idp/keycloak.go @@ -467,6 +467,47 @@ func (km *KeycloakManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from Keycloack +func (km *KeycloakManager) DeleteUser(userID string) error { + jwtToken, err := km.credentials.Authenticate() + if err != nil { + return err + } + + reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID)) + + req, err := http.NewRequest(http.MethodDelete, reqURL, nil) + if err != nil { + return err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountDeleteUser() + } + + resp, err := km.httpClient.Do(req) + if err != nil { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestError() + } + return err + } + defer resp.Body.Close() // nolint + + // In the docs, they specified 200, but in the endpoints, they return 204 + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) + } + + return nil +} + func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) { attrs := keycloakUserAttributes{} attrs.Set(wtAccountID, appMetadata.WTAccountID) diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index c6b5055d4..0e93c494c 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -319,6 +319,28 @@ func (om *OktaManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from Okta +func (om *OktaManager) DeleteUser(userID string) error { + resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil) + if err != nil { + fmt.Println(err.Error()) + return err + } + + if om.appMetrics != nil { + om.appMetrics.IDPMetrics().CountDeleteUser() + } + + if resp.StatusCode != http.StatusOK { + if om.appMetrics != nil { + om.appMetrics.IDPMetrics().CountRequestStatusError() + } + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) + } + + return nil +} + // updateUserProfileSchema updates the Okta user schema to include custom fields, // wt_account_id and wt_pending_invite. func updateUserProfileSchema(client *okta.Client) error { diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index fce2c7b37..73958a69e 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -428,7 +428,7 @@ func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMe return err } - resource := fmt.Sprintf("users/%s/metadata/_bulk", userID) + resource := fmt.Sprintf("users/%s", userID) _, err = zm.post(resource, string(payload)) if err != nil { return err @@ -447,6 +447,21 @@ func (zm *ZitadelManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from Zitadel +func (zm *ZitadelManager) DeleteUser(userID string) error { + resource := fmt.Sprintf("users/%s", userID) + if err := zm.delete(resource); err != nil { + return err + } + + if zm.appMetrics != nil { + zm.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil + +} + // getUserMetadata requests user metadata from zitadel via ID. func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) { resource := fmt.Sprintf("users/%s/metadata/_search", userID) @@ -500,6 +515,42 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) { return io.ReadAll(resp.Body) } +// delete perform Delete requests. +func (zm *ZitadelManager) delete(resource string) error { + jwtToken, err := zm.credentials.Authenticate() + if err != nil { + return err + } + + reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource) + req, err := http.NewRequest(http.MethodDelete, reqURL, nil) + if err != nil { + return err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + resp, err := zm.httpClient.Do(req) + if err != nil { + if zm.appMetrics != nil { + zm.appMetrics.IDPMetrics().CountRequestError() + } + + return err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + if zm.appMetrics != nil { + zm.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return fmt.Errorf("unable to delete %s, statusCode %d", reqURL, resp.StatusCode) + } + + return nil +} + // get perform Get requests. func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) { jwtToken, err := zm.credentials.Authenticate() diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 66661dbf8..b4a527e46 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -412,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) peersUpdateManager := NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 6c93765f4..fa35cfdef 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -503,7 +503,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { peersUpdateManager := server.NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 4bfa922c7..24bf9f3c9 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -31,7 +31,6 @@ type MockAccountManager struct { AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error) GetGroupFunc func(accountID, groupID string) (*server.Group, error) SaveGroupFunc func(accountID, userID string, group *server.Group) error - UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) DeleteGroupFunc func(accountID, userId, groupID string) error ListGroupsFunc func(accountID string) ([]*server.Group, error) GroupAddPeerFunc func(accountID, groupID, peerKey string) error @@ -54,7 +53,6 @@ type MockAccountManager struct { CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) SaveRouteFunc func(accountID, userID string, route *route.Route) error - UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) DeleteRouteFunc func(accountID, routeID, userID string) error ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) @@ -68,7 +66,6 @@ type MockAccountManager struct { GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - UpdateNameServerGroupFunc func(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) @@ -267,14 +264,6 @@ func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server. return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented") } -// UpdateGroup mock implementation of UpdateGroup from server.AccountManager interface -func (am *MockAccountManager) UpdateGroup(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) { - if am.UpdateGroupFunc != nil { - return am.UpdateGroupFunc(accountID, groupID, operations) - } - return nil, status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented") -} - // DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error { if am.DeleteGroupFunc != nil { @@ -435,14 +424,6 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented") } -// UpdateRoute mock implementation of UpdateRoute from server.AccountManager interface -func (am *MockAccountManager) UpdateRoute(accountID, ruleID string, operations []server.RouteUpdateOperation) (*route.Route, error) { - if am.UpdateRouteFunc != nil { - return am.UpdateRouteFunc(accountID, ruleID, operations) - } - return nil, status.Errorf(codes.Unimplemented, "method UpdateRoute not implemented") -} - // DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error { if am.DeleteRouteFunc != nil { @@ -533,14 +514,6 @@ func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGr return nil } -// UpdateNameServerGroup mocks UpdateNameServerGroup of the AccountManager interface -func (am *MockAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { - if am.UpdateNameServerGroupFunc != nil { - return am.UpdateNameServerGroupFunc(accountID, nsGroupID, userID, operations) - } - return nil, nil -} - // DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { if am.DeleteNameServerGroupFunc != nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index eb2127945..7025388ba 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -3,7 +3,6 @@ package server import ( "errors" "regexp" - "strconv" "unicode/utf8" "github.com/miekg/dns" @@ -15,54 +14,7 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -const ( - // UpdateNameServerGroupName indicates a nameserver group name update operation - UpdateNameServerGroupName NameServerGroupUpdateOperationType = iota - // UpdateNameServerGroupDescription indicates a nameserver group description update operation - UpdateNameServerGroupDescription - // UpdateNameServerGroupNameServers indicates a nameserver group nameservers list update operation - UpdateNameServerGroupNameServers - // UpdateNameServerGroupGroups indicates a nameserver group' groups update operation - UpdateNameServerGroupGroups - // UpdateNameServerGroupEnabled indicates a nameserver group status update operation - UpdateNameServerGroupEnabled - // UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation - UpdateNameServerGroupPrimary - // UpdateNameServerGroupDomains indicates a nameserver group' domains update operation - UpdateNameServerGroupDomains - - domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` -) - -// NameServerGroupUpdateOperationType operation type -type NameServerGroupUpdateOperationType int - -func (t NameServerGroupUpdateOperationType) String() string { - switch t { - case UpdateNameServerGroupDescription: - return "UpdateNameServerGroupDescription" - case UpdateNameServerGroupName: - return "UpdateNameServerGroupName" - case UpdateNameServerGroupNameServers: - return "UpdateNameServerGroupNameServers" - case UpdateNameServerGroupGroups: - return "UpdateNameServerGroupGroups" - case UpdateNameServerGroupEnabled: - return "UpdateNameServerGroupEnabled" - case UpdateNameServerGroupPrimary: - return "UpdateNameServerGroupPrimary" - case UpdateNameServerGroupDomains: - return "UpdateNameServerGroupDomains" - default: - return "InvalidOperation" - } -} - -// NameServerGroupUpdateOperation operation object with type and values to be applied -type NameServerGroupUpdateOperation struct { - Type NameServerGroupUpdateOperationType - Values []string -} +const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { @@ -172,109 +124,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n return nil } -// UpdateNameServerGroup updates existing nameserver group with set of operations -func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, err - } - - if len(operations) == 0 { - return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty") - } - - nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID] - if !ok { - return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID) - } - - newNSGroup := nsGroupToUpdate.Copy() - - for _, operation := range operations { - valuesCount := len(operation.Values) - if valuesCount < 1 { - return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String()) - } - - for _, value := range operation.Values { - if value == "" { - return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String()) - } - } - switch operation.Type { - case UpdateNameServerGroupDescription: - newNSGroup.Description = operation.Values[0] - case UpdateNameServerGroupName: - if valuesCount > 1 { - return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount) - } - err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups) - if err != nil { - return nil, err - } - newNSGroup.Name = operation.Values[0] - case UpdateNameServerGroupNameServers: - var nsList []nbdns.NameServer - for _, url := range operation.Values { - ns, err := nbdns.ParseNameServerURL(url) - if err != nil { - return nil, err - } - nsList = append(nsList, ns) - } - err = validateNSList(nsList) - if err != nil { - return nil, err - } - newNSGroup.NameServers = nsList - case UpdateNameServerGroupGroups: - err = validateGroups(operation.Values, account.Groups) - if err != nil { - return nil, err - } - newNSGroup.Groups = operation.Values - case UpdateNameServerGroupEnabled: - enabled, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) - } - newNSGroup.Enabled = enabled - case UpdateNameServerGroupPrimary: - primary, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0]) - } - newNSGroup.Primary = primary - case UpdateNameServerGroupDomains: - err = validateDomainInput(false, operation.Values) - if err != nil { - return nil, err - } - newNSGroup.Domains = operation.Values - } - } - - account.NameServerGroups[nsGroupID] = newNSGroup - - account.Network.IncSerial() - err = am.Store.SaveAccount(account) - if err != nil { - return nil, err - } - - err = am.updateAccountPeers(account) - if err != nil { - log.Error(err) - return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name) - } - - return newNSGroup.Copy(), nil -} - // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 9d4425056..26977116b 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -655,323 +655,6 @@ func TestSaveNameServerGroup(t *testing.T) { } } -func TestUpdateNameServerGroup(t *testing.T) { - nsGroupID := "testingNSGroup" - - existingNSGroup := &nbdns.NameServerGroup{ - ID: nsGroupID, - Name: "super", - Description: "super", - Primary: true, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("1.1.1.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("1.1.2.2"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - }, - Groups: []string{group1ID}, - Enabled: true, - } - - testCases := []struct { - name string - existingNSGroup *nbdns.NameServerGroup - nsGroupID string - operations []NameServerGroupUpdateOperation - shouldCreate bool - errFunc require.ErrorAssertionFunc - expectedNSGroup *nbdns.NameServerGroup - }{ - { - name: "Should Config Single Property", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"superNew"}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedNSGroup: &nbdns.NameServerGroup{ - ID: nsGroupID, - Name: "superNew", - Description: "super", - Primary: true, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("1.1.1.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("1.1.2.2"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - }, - Groups: []string{group1ID}, - Enabled: true, - }, - }, - { - name: "Should Config Multiple Properties", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"superNew"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupDescription, - Values: []string{"superDescription"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupGroups, - Values: []string{group1ID, group2ID}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupEnabled, - Values: []string{"false"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupPrimary, - Values: []string{"false"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupDomains, - Values: []string{validDomain}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedNSGroup: &nbdns.NameServerGroup{ - ID: nsGroupID, - Name: "superNew", - Description: "superDescription", - Primary: false, - Domains: []string{validDomain}, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("127.0.0.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("8.8.8.8"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - }, - Groups: []string{group1ID, group2ID}, - Enabled: false, - }, - }, - { - name: "Should Not Config On Invalid ID", - existingNSGroup: existingNSGroup, - nsGroupID: "nonExistingNSGroup", - errFunc: require.Error, - }, - { - name: "Should Not Config On Empty Operations", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{}, - errFunc: require.Error, - }, - { - name: "Should Not Config On Empty Values", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Empty String", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{""}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Name Large String", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid On Existing Name", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{existingNSGroupName}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid On Multiple Name Values", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"nameOne", "nameTwo"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Boolean", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupEnabled, - Values: []string{"yes"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Nameservers Wrong Schema", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"https://127.0.0.1:53"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Nameservers Wrong IP", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"udp://8.8.8.300:53"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Large Number Of Nameservers", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53", "udp://8.8.4.4:53"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid GroupID", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupGroups, - Values: []string{"nonExistingGroupID"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Domains", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupDomains, - Values: []string{invalidDomain}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Primary Status", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupPrimary, - Values: []string{"yes"}, - }, - }, - errFunc: require.Error, - }, - } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - am, err := createNSManager(t) - if err != nil { - t.Error("failed to create account manager") - } - - account, err := initTestNSAccount(t, am) - if err != nil { - t.Error("failed to init testing account") - } - - account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup - - err = am.Store.SaveAccount(account) - if err != nil { - t.Error("account should be saved") - } - - updatedRoute, err := am.UpdateNameServerGroup(account.Id, testCase.nsGroupID, userID, testCase.operations) - testCase.errFunc(t, err) - - if !testCase.shouldCreate { - return - } - - testCase.expectedNSGroup.ID = updatedRoute.ID - - if !testCase.expectedNSGroup.IsEqual(updatedRoute) { - t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedNSGroup) - } - - }) - } -} - func TestDeleteNameServerGroup(t *testing.T) { nsGroupID := "testingNSGroup" @@ -1061,7 +744,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false) } func createNSStore(t *testing.T) (Store, error) { diff --git a/management/server/route.go b/management/server/route.go index f51b7c2db..b232c2bb6 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -2,7 +2,6 @@ package server import ( "net/netip" - "strconv" "unicode/utf8" "github.com/netbirdio/netbird/management/proto" @@ -13,57 +12,6 @@ import ( log "github.com/sirupsen/logrus" ) -const ( - // UpdateRouteDescription indicates a route description update operation - UpdateRouteDescription RouteUpdateOperationType = iota - // UpdateRouteNetwork indicates a route IP update operation - UpdateRouteNetwork - // UpdateRoutePeer indicates a route peer update operation - UpdateRoutePeer - // UpdateRouteMetric indicates a route metric update operation - UpdateRouteMetric - // UpdateRouteMasquerade indicates a route masquerade update operation - UpdateRouteMasquerade - // UpdateRouteEnabled indicates a route enabled update operation - UpdateRouteEnabled - // UpdateRouteNetworkIdentifier indicates a route net ID update operation - UpdateRouteNetworkIdentifier - // UpdateRouteGroups indicates a group list update operation - UpdateRouteGroups -) - -// RouteUpdateOperationType operation type -type RouteUpdateOperationType int - -func (t RouteUpdateOperationType) String() string { - switch t { - case UpdateRouteDescription: - return "UpdateRouteDescription" - case UpdateRouteNetwork: - return "UpdateRouteNetwork" - case UpdateRoutePeer: - return "UpdateRoutePeer" - case UpdateRouteMetric: - return "UpdateRouteMetric" - case UpdateRouteMasquerade: - return "UpdateRouteMasquerade" - case UpdateRouteEnabled: - return "UpdateRouteEnabled" - case UpdateRouteNetworkIdentifier: - return "UpdateRouteNetworkIdentifier" - case UpdateRouteGroups: - return "UpdateRouteGroups" - default: - return "InvalidOperation" - } -} - -// RouteUpdateOperation operation object with type and values to be applied -type RouteUpdateOperation struct { - Type RouteUpdateOperationType - Values []string -} - // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { unlock := am.Store.AcquireAccountLock(accountID) @@ -241,109 +189,6 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave return nil } -// UpdateRoute updates existing route with set of operations -func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) { - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, err - } - - routeToUpdate, ok := account.Routes[routeID] - if !ok { - return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID) - } - - newRoute := routeToUpdate.Copy() - - for _, operation := range operations { - - if len(operation.Values) != 1 { - return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String()) - } - - switch operation.Type { - case UpdateRouteDescription: - newRoute.Description = operation.Values[0] - case UpdateRouteNetworkIdentifier: - if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" { - return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) - } - newRoute.NetID = operation.Values[0] - case UpdateRouteNetwork: - prefixType, prefix, err := route.ParseNetwork(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0]) - } - err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix) - if err != nil { - return nil, err - } - newRoute.Network = prefix - newRoute.NetworkType = prefixType - case UpdateRoutePeer: - if operation.Values[0] != "" { - peer := account.GetPeer(operation.Values[0]) - if peer == nil { - return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", operation.Values[0]) - } - } - - err = am.checkPrefixPeerExists(accountID, operation.Values[0], routeToUpdate.Network) - if err != nil { - return nil, err - } - newRoute.Peer = operation.Values[0] - case UpdateRouteMetric: - metric, err := strconv.Atoi(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0]) - } - if metric < route.MinMetric || metric > route.MaxMetric { - return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d", - operation.Values[0], - route.MinMetric, - route.MaxMetric, - ) - } - newRoute.Metric = metric - case UpdateRouteMasquerade: - masquerade, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0]) - } - newRoute.Masquerade = masquerade - case UpdateRouteEnabled: - enabled, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) - } - newRoute.Enabled = enabled - case UpdateRouteGroups: - err = validateGroups(operation.Values, account.Groups) - if err != nil { - return nil, err - } - newRoute.Groups = operation.Values - } - } - - account.Routes[routeID] = newRoute - - account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { - return nil, err - } - - err = am.updateAccountPeers(account) - if err != nil { - return nil, status.Errorf(status.Internal, "failed to update account peers") - } - return newRoute, nil -} - // DeleteRoute deletes route with routeID func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error { unlock := am.Store.AcquireAccountLock(accountID) diff --git a/management/server/route_test.go b/management/server/route_test.go index c943aee0b..81ce21a3f 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -524,265 +524,6 @@ func TestSaveRoute(t *testing.T) { } } -func TestUpdateRoute(t *testing.T) { - routeID := "testingRouteID" - - existingRoute := &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superRoute", - NetworkType: route.IPv4Network, - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - } - - testCases := []struct { - name string - existingRoute *route.Route - operations []RouteUpdateOperation - shouldCreate bool - errFunc require.ErrorAssertionFunc - expectedRoute *route.Route - }{ - { - name: "Happy Path Single OPS", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{peer2ID}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedRoute: &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superRoute", - NetworkType: route.IPv4Network, - Peer: peer2ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - }, - }, - { - name: "Happy Path Multiple OPS", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteDescription, - Values: []string{"great"}, - }, - { - Type: UpdateRouteNetwork, - Values: []string{"192.168.0.0/24"}, - }, - { - Type: UpdateRoutePeer, - Values: []string{peer2ID}, - }, - { - Type: UpdateRouteMetric, - Values: []string{"3030"}, - }, - { - Type: UpdateRouteMasquerade, - Values: []string{"true"}, - }, - { - Type: UpdateRouteEnabled, - Values: []string{"false"}, - }, - { - Type: UpdateRouteNetworkIdentifier, - Values: []string{"megaRoute"}, - }, - { - Type: UpdateRouteGroups, - Values: []string{routeGroup2}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedRoute: &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/24"), - NetID: "megaRoute", - NetworkType: route.IPv4Network, - Peer: peer2ID, - Description: "great", - Masquerade: true, - Metric: 3030, - Enabled: false, - Groups: []string{routeGroup2}, - }, - }, - { - name: "Empty Values Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - }, - }, - errFunc: require.Error, - }, - { - name: "Multiple Values Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{peer2ID, peer1ID}, - }, - }, - errFunc: require.Error, - }, - { - name: "Bad Prefix Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteNetwork, - Values: []string{"192.168.0.0/34"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Bad Peer Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{"non existing Peer"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Empty Peer", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{""}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedRoute: &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superRoute", - NetworkType: route.IPv4Network, - Peer: "", - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - }, - }, - { - name: "Large Network ID Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteNetworkIdentifier, - Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Empty Network ID Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteNetworkIdentifier, - Values: []string{""}, - }, - }, - errFunc: require.Error, - }, - { - name: "Invalid Metric Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteMetric, - Values: []string{"999999"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Invalid Boolean Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteMasquerade, - Values: []string{"yes"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Invalid Group Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteGroups, - Values: []string{routeInvalidGroup1}, - }, - }, - errFunc: require.Error, - }, - } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) - if err != nil { - t.Error("failed to create account manager") - } - - account, err := initTestRouteAccount(t, am) - if err != nil { - t.Error("failed to init testing account") - } - - account.Routes[testCase.existingRoute.ID] = testCase.existingRoute - - err = am.Store.SaveAccount(account) - if err != nil { - t.Error("account should be saved") - } - - updatedRoute, err := am.UpdateRoute(account.Id, testCase.existingRoute.ID, testCase.operations) - - testCase.errFunc(t, err) - - if !testCase.shouldCreate { - return - } - - testCase.expectedRoute.ID = updatedRoute.ID - - if !testCase.expectedRoute.IsEqual(updatedRoute) { - t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedRoute) - } - }) - } -} - func TestDeleteRoute(t *testing.T) { testingRoute := &route.Route{ ID: "testingRoute", @@ -940,7 +681,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false) } func createRouterStore(t *testing.T) (Store, error) { diff --git a/management/server/telemetry/idp_metrics.go b/management/server/telemetry/idp_metrics.go index 67a1d9e85..e9eee17bd 100644 --- a/management/server/telemetry/idp_metrics.go +++ b/management/server/telemetry/idp_metrics.go @@ -2,6 +2,7 @@ package telemetry import ( "context" + "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric/instrument" "go.opentelemetry.io/otel/metric/instrument/syncint64" @@ -13,6 +14,7 @@ type IDPMetrics struct { getUserByEmailCounter syncint64.Counter getAllAccountsCounter syncint64.Counter createUserCounter syncint64.Counter + deleteUserCounter syncint64.Counter getAccountCounter syncint64.Counter getUserByIDCounter syncint64.Counter authenticateRequestCounter syncint64.Counter @@ -39,6 +41,10 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error) if err != nil { return nil, err } + deleteUserCounter, err := meter.SyncInt64().Counter("management.idp.delete.user.counter", instrument.WithUnit("1")) + if err != nil { + return nil, err + } getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1")) if err != nil { return nil, err @@ -65,6 +71,7 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error) getUserByEmailCounter: getUserByEmailCounter, getAllAccountsCounter: getAllAccountsCounter, createUserCounter: createUserCounter, + deleteUserCounter: deleteUserCounter, getAccountCounter: getAccountCounter, getUserByIDCounter: getUserByIDCounter, authenticateRequestCounter: authenticateRequestCounter, @@ -88,6 +95,11 @@ func (idpMetrics *IDPMetrics) CountCreateUser() { idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1) } +// CountDeleteUser ... +func (idpMetrics *IDPMetrics) CountDeleteUser() { + idpMetrics.deleteUserCounter.Add(idpMetrics.ctx, 1) +} + // CountGetAllAccounts ... func (idpMetrics *IDPMetrics) CountGetAllAccounts() { idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1) diff --git a/management/server/user.go b/management/server/user.go index 8ee036df7..ebebe1e0f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -327,15 +327,43 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t return status.Errorf(status.NotFound, "user not found") } if executingUser.Role != UserRoleAdmin { - return status.Errorf(status.PermissionDenied, "only admins can delete service users") + return status.Errorf(status.PermissionDenied, "only admins can delete users") } - if !targetUser.IsServiceUser { - return status.Errorf(status.PermissionDenied, "regular users can not be deleted") + peers, err := account.FindUserPeers(targetUserID) + if err != nil { + return status.Errorf(status.Internal, "failed to find user peers") } - meta := map[string]any{"name": targetUser.ServiceUserName} - am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta) + if err := am.expireAndUpdatePeers(account, peers); err != nil { + log.Errorf("failed update deleted peers expiration: %s", err) + return err + } + + targetUserEmail, err := am.getEmailOfTargetUser(account.Id, initiatorUserID, targetUserID) + if err != nil { + log.Errorf("failed to resolve email address: %s", err) + return err + } + + var meta map[string]any + var eventAction activity.Activity + if targetUser.IsServiceUser { + meta = map[string]any{"name": targetUser.ServiceUserName} + eventAction = activity.ServiceUserDeleted + } else { + meta = map[string]any{"email": targetUserEmail} + eventAction = activity.UserDeleted + + } + am.storeEvent(initiatorUserID, targetUserID, accountID, eventAction, meta) + + if !isNil(am.idpManager) { + err := am.deleteUserFromIDP(targetUserID, accountID) + if err != nil { + return err + } + } delete(account.Users, targetUserID) @@ -609,23 +637,10 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd if err != nil { return nil, err } - var peerIDs []string - for _, peer := range blockedPeers { - peerIDs = append(peerIDs, peer.ID) - peer.MarkLoginExpired(true) - account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) - if err != nil { - log.Errorf("failed saving peer status while expiring peer %s", peer.ID) - return nil, err - } - } - am.peersUpdateManager.CloseChannels(peerIDs) - err = am.updateAccountPeers(account) - if err != nil { - log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID) - return nil, err + if err := am.expireAndUpdatePeers(account, blockedPeers); err != nil { + log.Errorf("failed update expired peers: %s", err) + return nil, err } } @@ -814,6 +829,67 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( return userInfos, nil } +// expireAndUpdatePeers expires all peers of the given user and updates them in the account +func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*Peer) error { + var peerIDs []string + for _, peer := range peers { + peerIDs = append(peerIDs, peer.ID) + peer.MarkLoginExpired(true) + account.UpdatePeer(peer) + if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { + return err + } + am.storeEvent( + peer.UserID, peer.ID, account.Id, + activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), + ) + } + + if len(peerIDs) != 0 { + // this will trigger peer disconnect from the management service + am.peersUpdateManager.CloseChannels(peerIDs) + if err := am.updateAccountPeers(account); err != nil { + return err + } + } + return nil +} + +func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID string) error { + if am.userDeleteFromIDPEnabled { + log.Debugf("user %s deleted from IdP", targetUserID) + err := am.idpManager.DeleteUser(targetUserID) + if err != nil { + return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err) + } + } else { + err := am.idpManager.UpdateUserAppMetadata(targetUserID, idp.AppMetadata{}) + if err != nil { + return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err) + } + + _, err = am.refreshCache(accountID) + if err != nil { + log.Errorf("refresh account (%q) cache: %v", accountID, err) + } + } + return nil +} + +func (am *DefaultAccountManager) getEmailOfTargetUser(accountId string, initiatorId, targetId string) (string, error) { + userInfos, err := am.GetUsersFromAccount(accountId, initiatorId) + if err != nil { + return "", err + } + for _, ui := range userInfos { + if ui.ID == targetId { + return ui.Email, nil + } + } + + return "", fmt.Errorf("email not found for user: %s", targetId) +} + func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index b07154663..bd64074b9 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -439,8 +439,9 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } err = am.DeleteUser(mockAccountID, mockUserID, mockUserID) - - assert.Errorf(t, err, "Regular users can not be deleted (yet)") + if err != nil { + t.Errorf("unexpected error: %s", err) + } } func TestDefaultAccountManager_GetUser(t *testing.T) {