Compare commits

...

11 Commits

Author SHA1 Message Date
Givi Khojanashvili
d4b6d7646c Handle user delete (#1113)
Implement user deletion across all IDP-ss. Expires all user peers
when the user is deleted. Users are permanently removed from a local
store, but in IDP, we remove Netbird attributes for the user
untilUserDeleteFromIDPEnabled setting is not enabled.

To test, an admin user should remove any additional users.

Until the UI incorporates this feature, use a curl DELETE request
targeting the /users/<USER_ID> management endpoint. Note that this
request only removes user attributes and doesn't trigger a delete
from the IDP.

To enable user removal from the IdP, set UserDeleteFromIDPEnabled
to true in account settings. Until we have a UI for this, make this
change directly in the store file.

Store the deleted email addresses in encrypted in activity store.
2023-09-19 18:08:40 +02:00
Bethuel Mmbaga
8febab4076 Improve Client Authentication (#1135)
* shutdown the pkce server on user cancellation

* Refactor openURL to exclusively manage authentication flow instructions and browser launching

* Refactor authentication flow initialization based on client OS

The NewOAuthFlow method now first checks the operating system and if it is a non-desktop Linux, it opts for Device Code Flow. PKCEFlow is tried first and if it fails, then it falls back on Device Code Flow. If both unsuccessful, the authentication process halts and error messages have been updated to provide more helpful feedback for troubleshooting authentication errors

* Replace log-based Linux desktop check with process check

To verify if a Linux OS is running a desktop environment in the Authentication utility, the log-based method that checks the XDG_CURRENT_DESKTOP env has been replaced with a method that checks directly if either X or Wayland display server processes are running. This method is more reliable as it directly checks for the display server process rather than relying on an environment variable that may not be set in all desktop environments.

* Refactor PKCE Authorization Flow to improve server handling

* refactor check for linux running desktop environment

* Improve server shutdown handling and encapsulate handlers with new server multiplexer

The changes enhance the way the server shuts down by specifying a context with timeout of 5 seconds, adding a safeguard to ensure the server halts even on potential hanging requests. Also, the server's root handler is now encapsulated within a new ServeMux instance, to support multiple registrations of a path
2023-09-19 19:06:18 +03:00
Zoltan Papp
34e2c6b943 Fix sso check (#1152)
Fix SSO check

- change the order of the PKCE and device auth flow check, prefer PKCE
- fix error handling in PKCE check
2023-09-18 16:04:53 +02:00
Yury Gargay
0be8c72601 Remove unused methods from AccountManager interface (#1149)
This PR removes the following unused methods from the AccountManager interface:
* `UpdateGroup`
* `UpdateNameServerGroup`
* `UpdateRoute`
2023-09-18 12:25:12 +02:00
Maycon Santos
c34e53477f Add signal port tests to CI workflow (#1148) 2023-09-14 17:01:14 +02:00
Fabio Fantoni
8d18190c94 fix NETBIRD_SIGNAL_PORT not working with custom port (#1143) (#1145)
Use NETBIRD_SIGNAL_PORT variable instead of the static port for signal
container in the docker-compose template to make setting of custom
signal port working

Signed-off-by: Fabio Fantoni <fabio.fantoni@m2r.biz>
2023-09-14 15:58:28 +02:00
Zoltan Papp
06bec61be9 Add Android test build (#1144)
Extend the CI with gomobile build.
With this step we can validate that the code can run on Android
2023-09-13 17:58:12 +02:00
Zoltan Papp
2135533f1d Fix Android build (#1142)
The source code files related to the Android firewall had incorrect build tags.
2023-09-13 17:36:24 +02:00
Bethuel Mmbaga
bb791d59f3 update check for linux running desktop (#1137) 2023-09-08 20:08:02 +02:00
Maycon Santos
30f1c54ed1 Fix: docker test for infrastructure files (#1136)
* Fix: docker test for infrastructure files

* Fix: docker test for infrastructure files
2023-09-08 19:28:34 +02:00
Maycon Santos
5c8541ef42 Set not found ebpf log to Info (#1134)
added an additional log event
2023-09-08 18:24:19 +02:00
55 changed files with 898 additions and 1317 deletions

View File

@@ -0,0 +1,41 @@
name: Android build validation
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v4
with:
go-version: "1.20.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v2
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v3
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
- name: Setup NDK
run: /usr/local/lib/android/sdk/tools/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build android nebtird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@@ -80,6 +80,7 @@ jobs:
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
run: |
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
@@ -91,6 +92,7 @@ jobs:
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
@@ -120,7 +122,7 @@ jobs:
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
count=$(docker compose ps --format json | jq '. | select(.Name | contains("infrastructure_files")) | .State' | grep -c running)
test $count -eq 4
working-directory: infrastructure_files

View File

@@ -84,10 +84,14 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}

View File

@@ -3,8 +3,6 @@ package cmd
import (
"context"
"fmt"
"os"
"runtime"
"strings"
"time"
@@ -195,60 +193,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
browserAuthMsg := "Please do the SSO login in your browser. \n" +
cmd.Println("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg
setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys"
authenticateUsingBrowser := func() {
cmd.Println(browserAuthMsg)
cmd.Println("")
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println(setupKeyAuthMsg)
}
}
switch runtime.GOOS {
case "windows", "darwin":
authenticateUsingBrowser()
case "linux":
if isLinuxRunningDesktop() {
authenticateUsingBrowser()
} else {
// If current flow is PKCE, it implies the server is anticipating the redirect to localhost.
// Devices lacking browser support are incompatible with this flow.Therefore,
// these devices will need to resort to setup keys instead.
if isPKCEFlow(verificationURIComplete) {
cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} else {
cmd.Println(browserAuthMsg)
}
}
verificationURIComplete + " " + codeMsg)
cmd.Println("")
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment.
func isLinuxRunningDesktop() bool {
for _, env := range os.Environ() {
values := strings.Split(env, "=")
if len(values) == 2 {
key, value := values[0], values[1]
if key == "XDG_CURRENT_DESKTOP" && value != "" {
return true
}
}
}
return false
}
// isPKCEFlow determines if the PKCE flow is active or not,
// by checking the existence of redirect_uri inside the verification URL.
func isPKCEFlow(verificationURL string) bool {
if verificationURL == "" {
return false
}
return strings.Contains(verificationURL, "redirect_uri")
}

View File

@@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
return nil, nil
}
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore)
eventStore, false)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
//go:build !linux
//go:build !linux || android
package acl

View File

@@ -1,3 +1,5 @@
//go:build !android
package acl
import (

View File

@@ -4,8 +4,8 @@ import (
"context"
"fmt"
"net/http"
"runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
@@ -57,25 +57,43 @@ func (t TokenInfo) GetTokenToUse() string {
return t.AccessToken
}
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
//
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
log.Debug("loading pkce authorization flow info")
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err == nil {
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
if runtime.GOOS == "linux" && !isLinuxRunningDesktop() {
return authenticateWithDeviceCodeFlow(ctx, config)
}
log.Debugf("loading pkce authorization flow info failed with error: %v", err)
log.Debugf("falling back to device authorization flow info")
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil {
// fallback to device code flow
return authenticateWithDeviceCodeFlow(ctx, config)
}
return pkceFlow, nil
}
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
s, ok := gstatus.FromError(err)
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
"Please proceed with setting up this device using setup keys " +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} else if ok && s.Code() == codes.Unimplemented {
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your server or use Setup Keys to login", config.ManagementURL)

View File

@@ -12,7 +12,6 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
@@ -80,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
}
// RequestAuthInfo requests a authorization code login flow information.
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
state, err := randomBytesInHex(24)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
@@ -114,64 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
go p.startServer(tokenChan, errChan)
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
}
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
log.Errorf("failed to close the server: %v", err)
}
}()
go p.startServer(server, tokenChan, errChan)
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case token := <-tokenChan:
return p.handleOAuthToken(token)
return p.parseOAuthToken(token)
case err := <-errChan:
return TokenInfo{}, err
}
}
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
var wg sync.WaitGroup
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
return
}
server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
go func() {
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errChan <- err
}
}()
wg.Add(1)
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
defer wg.Done()
tokenValidatorFunc := func() (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
token, err := tokenValidatorFunc()
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
token, err := p.handleRequest(req)
if err != nil {
renderPKCEFlowTmpl(w, err)
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
@@ -182,13 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC
tokenChan <- token
})
wg.Wait()
if err := server.Shutdown(context.Background()); err != nil {
log.Errorf("error while shutting down pkce flow server: %v", err)
server.Handler = mux
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errChan <- err
}
}
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on the state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
tokenInfo := TokenInfo{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"os"
"reflect"
"strings"
)
@@ -60,3 +61,8 @@ func isValidAccessToken(token string, audience string) error {
return fmt.Errorf("invalid JWT token audience field")
}
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool {
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}

View File

@@ -1,3 +1,3 @@
//go:build !linux
//go:build !linux || android
package checkfw

View File

@@ -1049,7 +1049,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
return nil, "", err
}
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore)
eventStore, false)
if err != nil {
return nil, "", err
}

View File

@@ -135,6 +135,7 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
p.removeTurnConn(endpointPort)
log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err)
return
}
err = p.sendPkg(buf[:n], endpointPort)
@@ -158,7 +159,7 @@ func (p *WGEBPFProxy) proxyToRemote() {
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Errorf("turn conn not found by port: %d", addr.Port)
log.Infof("turn conn not found by port: %d", addr.Port)
continue
}

View File

@@ -36,7 +36,7 @@ services:
volumes:
- $SIGNAL_VOLUMENAME:/var/lib/netbird
ports:
- 10000:80
- $NETBIRD_SIGNAL_PORT:80
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]

View File

@@ -21,4 +21,5 @@ NETBIRD_AUTH_USER_ID_CLAIM="email"
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email"
NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP
NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID
NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
NETBIRD_SIGNAL_PORT=12345

View File

@@ -61,7 +61,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{}
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore)
eventStore, false)
if err != nil {
t.Fatal(err)
}

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp"
@@ -142,12 +143,22 @@ var (
if disableSingleAccMode {
mgmtSingleAccModeDomain = ""
}
eventStore, err := sqlite.NewSQLiteStore(config.Datadir)
eventStore, key, err := initEventStore(config.Datadir, config.DataStoreEncryptionKey)
if err != nil {
return err
return fmt.Errorf("failed to initialize database: %s", err)
}
if key != "" {
log.Debugf("update config with activity store key")
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(mgmtConfig, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
}
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore)
dnsDomain, eventStore, userDeleteFromIDPEnabled)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
}
@@ -287,6 +298,20 @@ var (
}
)
func initEventStore(dataDir string, key string) (activity.Store, string, error) {
var err error
if key == "" {
log.Debugf("generate new activity store encryption key")
key, err = sqlite.GenerateKey()
if err != nil {
return nil, "", err
}
}
store, err := sqlite.NewSQLiteStore(dataDir, key)
return store, key, err
}
func notifyStop(msg string) {
select {
case stopCh <- 1:
@@ -440,6 +465,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
return loadedConfig, err
}
func updateMgmtConfig(path string, config *server.Config) error {
return util.WriteJson(path, config)
}
// OIDCConfigResponse used for parsing OIDC config response
type OIDCConfigResponse struct {
Issuer string `json:"issuer"`

View File

@@ -24,6 +24,7 @@ var (
disableMetrics bool
disableSingleAccMode bool
idpSignKeyRefreshEnabled bool
userDeleteFromIDPEnabled bool
rootCmd = &cobra.Command{
Use: "netbird-mgmt",
@@ -56,6 +57,7 @@ func init() {
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account")
rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@@ -80,7 +80,6 @@ type AccountManager interface {
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountID, userID string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
DeleteGroup(accountId, userId, groupID string) error
ListGroups(accountId string) ([]*Group, error)
GroupAddPeer(accountId, groupID, peerID string) error
@@ -93,13 +92,11 @@ type AccountManager interface {
GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
SaveRoute(accountID, userID string, route *route.Route) error
UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID, userID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroup(accountID, nsGroupID, userID string) error
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain() string
@@ -133,6 +130,9 @@ type DefaultAccountManager struct {
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
peerLoginExpiry Scheduler
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
userDeleteFromIDPEnabled bool
}
// Settings represents Account settings structure that can be modified via API and Dashboard
@@ -738,18 +738,19 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, userDeleteFromIDPEnabled bool,
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{
Store: store,
peersUpdateManager: peersUpdateManager,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(),
Store: store,
peersUpdateManager: peersUpdateManager,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
}
allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
@@ -874,33 +875,19 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func()
return account.GetNextPeerExpiration()
}
expiredPeers := account.GetExpiredPeers()
var peerIDs []string
for _, peer := range account.GetExpiredPeers() {
if peer.Status.LoginExpired {
continue
}
for _, peer := range expiredPeers {
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return account.GetNextPeerExpiration()
}
am.storeEvent(peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()))
}
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(peerIDs)
err = am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
return account.GetNextPeerExpiration()
}
if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
return account.GetNextPeerExpiration()
}
return account.GetNextPeerExpiration()
}
}
@@ -1605,19 +1592,3 @@ func newAccountWithId(accountID, userID, domain string) *Account {
}
return acc
}
func removeFromList(inputList []string, toRemove []string) []string {
toRemoveMap := make(map[string]struct{})
for _, item := range toRemove {
toRemoveMap[item] = struct{}{}
}
var resultList []string
for _, item := range inputList {
_, ok := toRemoveMap[item]
if !ok {
resultList = append(resultList, item)
}
}
return resultList
}

View File

@@ -2063,7 +2063,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore)
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore, false)
}
func createStore(t *testing.T) (Store, error) {

View File

@@ -104,6 +104,8 @@ const (
UserBlocked
// UserUnblocked indicates that a user unblocked another user
UserUnblocked
// UserDeleted indicates that a user deleted another user
UserDeleted
// GroupDeleted indicates that a user deleted group
GroupDeleted
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
@@ -162,6 +164,7 @@ var activityMap = map[Activity]Code{
ServiceUserDeleted: {"Service user deleted", "service.user.delete"},
UserBlocked: {"User blocked", "user.block"},
UserUnblocked: {"User unblocked", "user.unblock"},
UserDeleted: {"User deleted", "user.delete"},
GroupDeleted: {"Group deleted", "group.delete"},
UserLoggedInPeer: {"User logged in peer", "user.peer.login"},
PeerLoginExpired: {"Peer login expired", "peer.login.expire"},

View File

@@ -18,10 +18,13 @@ type Event struct {
ID uint64
// InitiatorID is the ID of an object that initiated the event (e.g., a user)
InitiatorID string
// InitiatorEmail is the email address of an object that initiated the event. This will be set on deleted users only
InitiatorEmail string
// TargetID is the ID of an object that was effected by the event (e.g., a peer)
TargetID string
// AccountID is the ID of an account where the event happened
AccountID string
// Meta of the event, e.g. deleted peer information like name, IP, etc
Meta map[string]any
}
@@ -35,12 +38,13 @@ func (e *Event) Copy() *Event {
}
return &Event{
Timestamp: e.Timestamp,
Activity: e.Activity,
ID: e.ID,
InitiatorID: e.InitiatorID,
TargetID: e.TargetID,
AccountID: e.AccountID,
Meta: meta,
Timestamp: e.Timestamp,
Activity: e.Activity,
ID: e.ID,
InitiatorID: e.InitiatorID,
InitiatorEmail: e.InitiatorEmail,
TargetID: e.TargetID,
AccountID: e.AccountID,
Meta: meta,
}
}

View File

@@ -0,0 +1,81 @@
package sqlite
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type EmailEncrypt struct {
block cipher.Block
}
func GenerateKey() (string, error) {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
return "", err
}
readableKey := base64.StdEncoding.EncodeToString(key)
return readableKey, nil
}
func NewEmailEncrypt(key string) (*EmailEncrypt, error) {
binKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(binKey)
if err != nil {
return nil, err
}
ec := &EmailEncrypt{
block: block,
}
return ec, nil
}
func (ec *EmailEncrypt) Encrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, plainText)
return base64.StdEncoding.EncodeToString(cipherText)
}
func (ec *EmailEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
cbc := cipher.NewCBCDecrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, cipherText)
payload, err := pkcs5UnPadding(cipherText)
if err != nil {
return "", err
}
return string(payload), nil
}
func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
paddingLen := int(src[srcLen-1])
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
return nil, fmt.Errorf("padding size error")
}
return src[:srcLen-paddingLen], nil
}

View File

@@ -0,0 +1,63 @@
package sqlite
import (
"testing"
)
func TestGenerateKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewEmailEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.Decrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewEmailEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
newKey, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err = NewEmailEncrypt(newKey)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
res, err := ee.Decrypt(encrypted)
if err == nil || res == testData {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}

View File

@@ -3,14 +3,14 @@ package sqlite
import (
"database/sql"
"encoding/json"
"github.com/netbirdio/netbird/management/server/activity"
// sqlite driver
"fmt"
"path/filepath"
"time"
_ "github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3" // sqlite driver
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
)
const (
@@ -25,35 +25,62 @@ const (
"meta TEXT," +
" target_id TEXT);"
selectDescQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" +
" FROM events WHERE account_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?;"
selectAscQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" +
" FROM events WHERE account_id = ? ORDER BY timestamp ASC LIMIT ? OFFSET ?;"
creatTableAccountEmailQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta
FROM events
LEFT JOIN deleted_users i ON events.initiator_id = i.id
LEFT JOIN deleted_users t ON events.target_id = t.id
WHERE account_id = ?
ORDER BY timestamp DESC LIMIT ? OFFSET ?;`
selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta
FROM events
LEFT JOIN deleted_users i ON events.initiator_id = i.id
LEFT JOIN deleted_users t ON events.target_id = t.id
WHERE account_id = ?
ORDER BY timestamp ASC LIMIT ? OFFSET ?;`
insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " +
"VALUES(?, ?, ?, ?, ?, ?)"
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email) VALUES(?, ?)`
)
// Store is the implementation of the activity.Store interface backed by SQLite
type Store struct {
db *sql.DB
db *sql.DB
emailEncrypt *EmailEncrypt
insertStatement *sql.Stmt
selectAscStatement *sql.Stmt
selectDescStatement *sql.Stmt
deleteUserStmt *sql.Stmt
}
// NewSQLiteStore creates a new Store with an event table if not exists.
func NewSQLiteStore(dataDir string) (*Store, error) {
func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
dbFile := filepath.Join(dataDir, eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
if err != nil {
return nil, err
}
crypt, err := NewEmailEncrypt(encryptionKey)
if err != nil {
return nil, err
}
_, err = db.Exec(createTableQuery)
if err != nil {
return nil, err
}
_, err = db.Exec(creatTableAccountEmailQuery)
if err != nil {
return nil, err
}
insertStmt, err := db.Prepare(insertQuery)
if err != nil {
return nil, err
@@ -69,25 +96,35 @@ func NewSQLiteStore(dataDir string) (*Store, error) {
return nil, err
}
return &Store{
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
return nil, err
}
s := &Store{
db: db,
emailEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
}, nil
deleteUserStmt: deleteUserStmt,
}
return s, nil
}
func processResult(result *sql.Rows) ([]*activity.Event, error) {
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
events := make([]*activity.Event, 0)
for result.Next() {
var id int64
var operation activity.Activity
var timestamp time.Time
var initiator string
var initiatorEmail *string
var target string
var targetEmail *string
var account string
var jsonMeta string
err := result.Scan(&id, &operation, &timestamp, &initiator, &target, &account, &jsonMeta)
err := result.Scan(&id, &operation, &timestamp, &initiator, &initiatorEmail, &target, &targetEmail, &account, &jsonMeta)
if err != nil {
return nil, err
}
@@ -100,7 +137,17 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
}
}
events = append(events, &activity.Event{
if targetEmail != nil {
email, err := store.emailEncrypt.Decrypt(*targetEmail)
if err != nil {
log.Errorf("failed to decrypt email address for target id: %s", target)
meta["email"] = ""
} else {
meta["email"] = email
}
}
event := &activity.Event{
Timestamp: timestamp,
Activity: operation,
ID: uint64(id),
@@ -108,7 +155,18 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
TargetID: target,
AccountID: account,
Meta: meta,
})
}
if initiatorEmail != nil {
email, err := store.emailEncrypt.Decrypt(*initiatorEmail)
if err != nil {
log.Errorf("failed to decrypt email address of initiator: %s", initiator)
} else {
event.InitiatorEmail = email
}
}
events = append(events, event)
}
return events, nil
@@ -127,13 +185,18 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
}
defer result.Close() //nolint
return processResult(result)
return store.processResult(result)
}
// Save an event in the SQLite events table
// Save an event in the SQLite events table end encrypt the "email" element in meta map
func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
var jsonMeta string
if event.Meta != nil {
meta, err := store.saveDeletedUserEmailInEncrypted(event)
if err != nil {
return nil, err
}
if meta != nil {
metaBytes, err := json.Marshal(event.Meta)
if err != nil {
return nil, err
@@ -156,6 +219,29 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
return eventCopy, nil
}
// saveDeletedUserEmailInEncrypted if the meta contains email then store it in encrypted way and delete this item from
// meta map
func (store *Store) saveDeletedUserEmailInEncrypted(event *activity.Event) (map[string]any, error) {
email, ok := event.Meta["email"]
if !ok {
return event.Meta, nil
}
delete(event.Meta, "email")
encrypted := store.emailEncrypt.Encrypt(fmt.Sprintf("%s", email))
_, err := store.deleteUserStmt.Exec(event.TargetID, encrypted)
if err != nil {
return nil, err
}
if len(event.Meta) == 1 {
return nil, nil // nolint
}
delete(event.Meta, "email")
return event.Meta, nil
}
// Close the Store
func (store *Store) Close() error {
if store.db != nil {

View File

@@ -12,7 +12,8 @@ import (
func TestNewSQLiteStore(t *testing.T) {
dataDir := t.TempDir()
store, err := NewSQLiteStore(dataDir)
key, _ := GenerateKey()
store, err := NewSQLiteStore(dataDir, key)
if err != nil {
t.Fatal(err)
return

View File

@@ -35,7 +35,8 @@ type Config struct {
TURNConfig *TURNConfig
Signal *Host
Datadir string
Datadir string
DataStoreEncryptionKey string
HttpConfig *HttpServerConfig

View File

@@ -191,7 +191,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore)
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore, false)
}
func createDNSStore(t *testing.T) (Store, error) {

View File

@@ -33,26 +33,6 @@ type Group struct {
Peers []string
}
const (
// UpdateGroupName indicates a name update operation
UpdateGroupName GroupUpdateOperationType = iota
// InsertPeersToGroup indicates insert peers to group operation
InsertPeersToGroup
// RemovePeersFromGroup indicates a remove peers from group operation
RemovePeersFromGroup
// UpdateGroupPeers indicates a replacement of group peers list
UpdateGroupPeers
)
// GroupUpdateOperationType operation type
type GroupUpdateOperationType int
// GroupUpdateOperation operation object with type and values to be applied
type GroupUpdateOperation struct {
Type GroupUpdateOperationType
Values []string
}
// EventMeta returns activity event meta related to the group
func (g *Group) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
@@ -165,57 +145,6 @@ func difference(a, b []string) []string {
return diff
}
// UpdateGroup updates a group using a list of operations
func (am *DefaultAccountManager) UpdateGroup(accountID string,
groupID string, operations []GroupUpdateOperation,
) (*Group, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
groupToUpdate, ok := account.Groups[groupID]
if !ok {
return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID)
}
group := groupToUpdate.Copy()
for _, operation := range operations {
switch operation.Type {
case UpdateGroupName:
group.Name = operation.Values[0]
case UpdateGroupPeers:
group.Peers = operation.Values
case InsertPeersToGroup:
sourceList := group.Peers
resultList := removeFromList(sourceList, operation.Values)
group.Peers = append(resultList, operation.Values...)
case RemovePeersFromGroup:
sourceList := group.Peers
resultList := removeFromList(sourceList, operation.Values)
group.Peers = resultList
}
}
account.Groups[groupID] = group
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
return nil, err
}
err = am.updateAccountPeers(account)
if err != nil {
return nil, err
}
return group, nil
}
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountLock(accountId)

0
management/server/http/api/generate.sh Normal file → Executable file
View File

View File

@@ -922,6 +922,10 @@ components:
description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
type: string
example: google-oauth2|123456789012345678901
initiator_email:
description: The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event.
type: string
example: demo@netbird.io
target_id:
description: The ID of the target of the event. E.g., an ID of the peer that a user removed.
type: string
@@ -938,6 +942,7 @@ components:
- activity
- activity_code
- initiator_id
- initiator_email
- target_id
- meta
responses:

View File

@@ -164,6 +164,9 @@ type Event struct {
// Id Event unique identifier
Id string `json:"id"`
// InitiatorEmail The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event.
InitiatorEmail string `json:"initiator_email"`
// InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
InitiatorId string `json:"initiator_id"`

View File

@@ -45,14 +45,46 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
util.WriteError(err, w)
return
}
events := make([]*api.Event, 0)
for _, e := range accountEvents {
events = append(events, toEventResponse(e))
events := make([]*api.Event, len(accountEvents))
for i, e := range accountEvents {
events[i] = toEventResponse(e)
}
err = h.fillEventsWithInitiatorEmail(events, account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, events)
}
func (h *EventsHandler) fillEventsWithInitiatorEmail(events []*api.Event, accountId, userId string) error {
// build email map based on users
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId)
if err != nil {
log.Errorf("failed to get users from account: %s", err)
return err
}
emails := make(map[string]string)
for _, ui := range userInfos {
emails[ui.ID] = ui.Email
}
// fill event with email of initiator
var ok bool
for _, event := range events {
if event.InitiatorEmail == "" {
event.InitiatorEmail, ok = emails[event.InitiatorId]
if !ok {
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
}
}
}
return nil
}
func toEventResponse(event *activity.Event) *api.Event {
meta := make(map[string]string)
if event.Meta != nil {
@@ -60,13 +92,15 @@ func toEventResponse(event *activity.Event) *api.Event {
meta[s] = fmt.Sprintf("%v", a)
}
}
return &api.Event{
Id: fmt.Sprint(event.ID),
InitiatorId: event.InitiatorID,
Activity: event.Activity.Message(),
ActivityCode: api.EventActivityCode(event.Activity.StringCode()),
TargetId: event.TargetID,
Timestamp: event.Timestamp,
Meta: meta,
e := &api.Event{
Id: fmt.Sprint(event.ID),
InitiatorId: event.InitiatorID,
InitiatorEmail: event.InitiatorEmail,
Activity: event.Activity.Message(),
ActivityCode: api.EventActivityCode(event.Activity.StringCode()),
TargetId: event.TargetID,
Timestamp: event.Timestamp,
Meta: meta,
}
return e
}

View File

@@ -37,6 +37,9 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
},
}, user, nil
},
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {

View File

@@ -53,22 +53,6 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
Issued: server.GroupIssuedAPI,
}, nil
},
UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
var group server.Group
group.ID = groupID
for _, operation := range operations {
switch operation.Type {
case server.UpdateGroupName:
group.Name = operation.Values[0]
case server.UpdateGroupPeers, server.InsertPeersToGroup:
group.Peers = operation.Values
case server.RemovePeersFromGroup:
default:
return nil, fmt.Errorf("no operation")
}
}
return &group, nil
},
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
for _, peer := range TestPeers {
if peer.IP.String() == peerIP {

View File

@@ -88,31 +88,6 @@ func initNameserversTestData() *NameserversHandler {
}
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
UpdateNameServerGroupFunc: func(accountID, nsGroupID, _ string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
nsGroupToUpdate := baseExistingNSGroup.Copy()
if nsGroupID != nsGroupToUpdate.ID {
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
}
for _, operation := range operations {
switch operation.Type {
case server.UpdateNameServerGroupName:
nsGroupToUpdate.Name = operation.Values[0]
case server.UpdateNameServerGroupDescription:
nsGroupToUpdate.Description = operation.Values[0]
case server.UpdateNameServerGroupNameServers:
var parsedNSList []nbdns.NameServer
for _, nsURL := range operation.Values {
parsed, err := nbdns.ParseNameServerURL(nsURL)
if err != nil {
return nil, err
}
parsedNSList = append(parsedNSList, parsed)
}
nsGroupToUpdate.NameServers = parsedNSList
}
}
return nsGroupToUpdate, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, testingAccount.Users["test_user"], nil
},

View File

@@ -8,7 +8,6 @@ import (
"net/http"
"net/http/httptest"
"net/netip"
"strconv"
"testing"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -108,38 +107,6 @@ func initRoutesTestData() *RoutesHandler {
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
}, nil
},
UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
routeToUpdate := baseExistingRoute
if routeID != routeToUpdate.ID {
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
}
for _, operation := range operations {
switch operation.Type {
case server.UpdateRouteNetwork:
routeToUpdate.NetworkType, routeToUpdate.Network, _ = route.ParseNetwork(operation.Values[0])
case server.UpdateRouteDescription:
routeToUpdate.Description = operation.Values[0]
case server.UpdateRouteNetworkIdentifier:
routeToUpdate.NetID = operation.Values[0]
case server.UpdateRoutePeer:
routeToUpdate.Peer = operation.Values[0]
if routeToUpdate.Peer == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", routeToUpdate.Peer)
}
case server.UpdateRouteMetric:
routeToUpdate.Metric, _ = strconv.Atoi(operation.Values[0])
case server.UpdateRouteMasquerade:
routeToUpdate.Masquerade, _ = strconv.ParseBool(operation.Values[0])
case server.UpdateRouteEnabled:
routeToUpdate.Enabled, _ = strconv.ParseBool(operation.Values[0])
case server.UpdateRouteGroups:
routeToUpdate.Groups = operation.Values
default:
return nil, fmt.Errorf("no operation")
}
}
return routeToUpdate, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, testingAccount.Users["test_user"], nil
},

View File

@@ -513,7 +513,9 @@ func buildUserExportRequest() (string, error) {
return string(str), nil
}
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
func (am *Auth0Manager) createRequest(
method string, endpoint string, body io.Reader,
) (*http.Request, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
@@ -521,17 +523,23 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*
reqURL := am.authIssuer + endpoint
payload := strings.NewReader(payloadStr)
req, err := http.NewRequest("POST", reqURL, payload)
req, err := http.NewRequest(method, reqURL, body)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
return req, nil
}
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr))
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/json")
return req, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
@@ -737,6 +745,38 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
return nil
}
// DeleteUser from Auth0
func (am *Auth0Manager) DeleteUser(userID string) error {
req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
if err != nil {
return err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.Debugf("execute delete request: %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer func() {
err = resp.Body.Close()
if err != nil {
log.Errorf("close delete request body: %v", err)
}
}()
if resp.StatusCode != 204 {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
// If the status is "completed", then return the downloadLink
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {

View File

@@ -12,9 +12,10 @@ import (
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// AuthentikManager authentik manager client instance.
@@ -453,6 +454,38 @@ func (am *AuthentikManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Authentik
func (am *AuthentikManager) DeleteUser(userID string) error {
ctx, err := am.authenticationContext()
if err != nil {
return err
}
userPk, err := strconv.ParseInt(userID, 10, 32)
if err != nil {
return err
}
resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute()
if err != nil {
return err
}
defer resp.Body.Close() // nolint
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountDeleteUser()
}
if resp.StatusCode != http.StatusNoContent {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode)
}
return nil
}
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {

View File

@@ -454,6 +454,43 @@ func (am *AzureManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Azure
func (am *AzureManager) DeleteUser(userID string) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("delete idp user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountDeleteUser()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
q := url.Values{}
q.Add("$select", extensionFields)

View File

@@ -254,6 +254,19 @@ func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from GoogleWorkspace.
func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
if err := gm.usersService.Delete(userID).Do(); err != nil {
return err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountDeleteUser()
}
return nil
}
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
// If that fails, it falls back to using the default Google credentials path.

View File

@@ -18,6 +18,7 @@ type Manager interface {
CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmail(email string) ([]*UserData, error)
InviteUserByID(userID string) error
DeleteUser(userID string) error
}
// ClientConfig defines common client configuration for all IdP manager

View File

@@ -467,6 +467,47 @@ func (km *KeycloakManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Keycloack
func (km *KeycloakManager) DeleteUser(userID string) error {
jwtToken, err := km.credentials.Authenticate()
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountDeleteUser()
}
resp, err := km.httpClient.Do(req)
if err != nil {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close() // nolint
// In the docs, they specified 200, but in the endpoints, they return 204
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
attrs := keycloakUserAttributes{}
attrs.Set(wtAccountID, appMetadata.WTAccountID)

View File

@@ -319,6 +319,28 @@ func (om *OktaManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Okta
func (om *OktaManager) DeleteUser(userID string) error {
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
if err != nil {
fmt.Println(err.Error())
return err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountDeleteUser()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
// updateUserProfileSchema updates the Okta user schema to include custom fields,
// wt_account_id and wt_pending_invite.
func updateUserProfileSchema(client *okta.Client) error {

View File

@@ -428,7 +428,7 @@ func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMe
return err
}
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID)
resource := fmt.Sprintf("users/%s", userID)
_, err = zm.post(resource, string(payload))
if err != nil {
return err
@@ -447,6 +447,21 @@ func (zm *ZitadelManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Zitadel
func (zm *ZitadelManager) DeleteUser(userID string) error {
resource := fmt.Sprintf("users/%s", userID)
if err := zm.delete(resource); err != nil {
return err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountDeleteUser()
}
return nil
}
// getUserMetadata requests user metadata from zitadel via ID.
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
@@ -500,6 +515,42 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
return io.ReadAll(resp.Body)
}
// delete perform Delete requests.
func (zm *ZitadelManager) delete(resource string) error {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := zm.httpClient.Do(req)
if err != nil {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete %s, statusCode %d", reqURL, resp.StatusCode)
}
return nil
}
// get perform Get requests.
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate()

View File

@@ -412,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
peersUpdateManager := NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{}
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "",
eventStore)
eventStore, false)
if err != nil {
return nil, "", err
}

View File

@@ -503,7 +503,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
peersUpdateManager := server.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{}
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore)
eventStore, false)
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}

View File

@@ -31,7 +31,6 @@ type MockAccountManager struct {
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
SaveGroupFunc func(accountID, userID string, group *server.Group) error
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*server.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
@@ -54,7 +53,6 @@ type MockAccountManager struct {
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
@@ -68,7 +66,6 @@ type MockAccountManager struct {
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroupFunc func(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
@@ -267,14 +264,6 @@ func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server.
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
}
// UpdateGroup mock implementation of UpdateGroup from server.AccountManager interface
func (am *MockAccountManager) UpdateGroup(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
if am.UpdateGroupFunc != nil {
return am.UpdateGroupFunc(accountID, groupID, operations)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented")
}
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error {
if am.DeleteGroupFunc != nil {
@@ -435,14 +424,6 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R
return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented")
}
// UpdateRoute mock implementation of UpdateRoute from server.AccountManager interface
func (am *MockAccountManager) UpdateRoute(accountID, ruleID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
if am.UpdateRouteFunc != nil {
return am.UpdateRouteFunc(accountID, ruleID, operations)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateRoute not implemented")
}
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error {
if am.DeleteRouteFunc != nil {
@@ -533,14 +514,6 @@ func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGr
return nil
}
// UpdateNameServerGroup mocks UpdateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
if am.UpdateNameServerGroupFunc != nil {
return am.UpdateNameServerGroupFunc(accountID, nsGroupID, userID, operations)
}
return nil, nil
}
// DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface
func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
if am.DeleteNameServerGroupFunc != nil {

View File

@@ -3,7 +3,6 @@ package server
import (
"errors"
"regexp"
"strconv"
"unicode/utf8"
"github.com/miekg/dns"
@@ -15,54 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
const (
// UpdateNameServerGroupName indicates a nameserver group name update operation
UpdateNameServerGroupName NameServerGroupUpdateOperationType = iota
// UpdateNameServerGroupDescription indicates a nameserver group description update operation
UpdateNameServerGroupDescription
// UpdateNameServerGroupNameServers indicates a nameserver group nameservers list update operation
UpdateNameServerGroupNameServers
// UpdateNameServerGroupGroups indicates a nameserver group' groups update operation
UpdateNameServerGroupGroups
// UpdateNameServerGroupEnabled indicates a nameserver group status update operation
UpdateNameServerGroupEnabled
// UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation
UpdateNameServerGroupPrimary
// UpdateNameServerGroupDomains indicates a nameserver group' domains update operation
UpdateNameServerGroupDomains
domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
)
// NameServerGroupUpdateOperationType operation type
type NameServerGroupUpdateOperationType int
func (t NameServerGroupUpdateOperationType) String() string {
switch t {
case UpdateNameServerGroupDescription:
return "UpdateNameServerGroupDescription"
case UpdateNameServerGroupName:
return "UpdateNameServerGroupName"
case UpdateNameServerGroupNameServers:
return "UpdateNameServerGroupNameServers"
case UpdateNameServerGroupGroups:
return "UpdateNameServerGroupGroups"
case UpdateNameServerGroupEnabled:
return "UpdateNameServerGroupEnabled"
case UpdateNameServerGroupPrimary:
return "UpdateNameServerGroupPrimary"
case UpdateNameServerGroupDomains:
return "UpdateNameServerGroupDomains"
default:
return "InvalidOperation"
}
}
// NameServerGroupUpdateOperation operation object with type and values to be applied
type NameServerGroupUpdateOperation struct {
Type NameServerGroupUpdateOperationType
Values []string
}
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
@@ -172,109 +124,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
return nil
}
// UpdateNameServerGroup updates existing nameserver group with set of operations
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
if len(operations) == 0 {
return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty")
}
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
if !ok {
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
}
newNSGroup := nsGroupToUpdate.Copy()
for _, operation := range operations {
valuesCount := len(operation.Values)
if valuesCount < 1 {
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
}
for _, value := range operation.Values {
if value == "" {
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
}
}
switch operation.Type {
case UpdateNameServerGroupDescription:
newNSGroup.Description = operation.Values[0]
case UpdateNameServerGroupName:
if valuesCount > 1 {
return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
}
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
if err != nil {
return nil, err
}
newNSGroup.Name = operation.Values[0]
case UpdateNameServerGroupNameServers:
var nsList []nbdns.NameServer
for _, url := range operation.Values {
ns, err := nbdns.ParseNameServerURL(url)
if err != nil {
return nil, err
}
nsList = append(nsList, ns)
}
err = validateNSList(nsList)
if err != nil {
return nil, err
}
newNSGroup.NameServers = nsList
case UpdateNameServerGroupGroups:
err = validateGroups(operation.Values, account.Groups)
if err != nil {
return nil, err
}
newNSGroup.Groups = operation.Values
case UpdateNameServerGroupEnabled:
enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
}
newNSGroup.Enabled = enabled
case UpdateNameServerGroupPrimary:
primary, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
}
newNSGroup.Primary = primary
case UpdateNameServerGroupDomains:
err = validateDomainInput(false, operation.Values)
if err != nil {
return nil, err
}
newNSGroup.Domains = operation.Values
}
}
account.NameServerGroups[nsGroupID] = newNSGroup
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name)
}
return newNSGroup.Copy(), nil
}
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {

View File

@@ -655,323 +655,6 @@ func TestSaveNameServerGroup(t *testing.T) {
}
}
func TestUpdateNameServerGroup(t *testing.T) {
nsGroupID := "testingNSGroup"
existingNSGroup := &nbdns.NameServerGroup{
ID: nsGroupID,
Name: "super",
Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID},
Enabled: true,
}
testCases := []struct {
name string
existingNSGroup *nbdns.NameServerGroup
nsGroupID string
operations []NameServerGroupUpdateOperation
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedNSGroup *nbdns.NameServerGroup
}{
{
name: "Should Config Single Property",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"superNew"},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
ID: nsGroupID,
Name: "superNew",
Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID},
Enabled: true,
},
},
{
name: "Should Config Multiple Properties",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"superNew"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDescription,
Values: []string{"superDescription"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupGroups,
Values: []string{group1ID, group2ID},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupEnabled,
Values: []string{"false"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupPrimary,
Values: []string{"false"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDomains,
Values: []string{validDomain},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
ID: nsGroupID,
Name: "superNew",
Description: "superDescription",
Primary: false,
Domains: []string{validDomain},
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("127.0.0.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID, group2ID},
Enabled: false,
},
},
{
name: "Should Not Config On Invalid ID",
existingNSGroup: existingNSGroup,
nsGroupID: "nonExistingNSGroup",
errFunc: require.Error,
},
{
name: "Should Not Config On Empty Operations",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{},
errFunc: require.Error,
},
{
name: "Should Not Config On Empty Values",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Empty String",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{""},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Name Large String",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid On Existing Name",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{existingNSGroupName},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid On Multiple Name Values",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"nameOne", "nameTwo"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Boolean",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupEnabled,
Values: []string{"yes"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Nameservers Wrong Schema",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"https://127.0.0.1:53"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Nameservers Wrong IP",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"udp://8.8.8.300:53"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Large Number Of Nameservers",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53", "udp://8.8.4.4:53"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid GroupID",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupGroups,
Values: []string{"nonExistingGroupID"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Domains",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDomains,
Values: []string{invalidDomain},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Primary Status",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupPrimary,
Values: []string{"yes"},
},
},
errFunc: require.Error,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
am, err := createNSManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
err = am.Store.SaveAccount(account)
if err != nil {
t.Error("account should be saved")
}
updatedRoute, err := am.UpdateNameServerGroup(account.Id, testCase.nsGroupID, userID, testCase.operations)
testCase.errFunc(t, err)
if !testCase.shouldCreate {
return
}
testCase.expectedNSGroup.ID = updatedRoute.ID
if !testCase.expectedNSGroup.IsEqual(updatedRoute) {
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedNSGroup)
}
})
}
}
func TestDeleteNameServerGroup(t *testing.T) {
nsGroupID := "testingNSGroup"
@@ -1061,7 +744,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore)
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false)
}
func createNSStore(t *testing.T) (Store, error) {

View File

@@ -2,7 +2,6 @@ package server
import (
"net/netip"
"strconv"
"unicode/utf8"
"github.com/netbirdio/netbird/management/proto"
@@ -13,57 +12,6 @@ import (
log "github.com/sirupsen/logrus"
)
const (
// UpdateRouteDescription indicates a route description update operation
UpdateRouteDescription RouteUpdateOperationType = iota
// UpdateRouteNetwork indicates a route IP update operation
UpdateRouteNetwork
// UpdateRoutePeer indicates a route peer update operation
UpdateRoutePeer
// UpdateRouteMetric indicates a route metric update operation
UpdateRouteMetric
// UpdateRouteMasquerade indicates a route masquerade update operation
UpdateRouteMasquerade
// UpdateRouteEnabled indicates a route enabled update operation
UpdateRouteEnabled
// UpdateRouteNetworkIdentifier indicates a route net ID update operation
UpdateRouteNetworkIdentifier
// UpdateRouteGroups indicates a group list update operation
UpdateRouteGroups
)
// RouteUpdateOperationType operation type
type RouteUpdateOperationType int
func (t RouteUpdateOperationType) String() string {
switch t {
case UpdateRouteDescription:
return "UpdateRouteDescription"
case UpdateRouteNetwork:
return "UpdateRouteNetwork"
case UpdateRoutePeer:
return "UpdateRoutePeer"
case UpdateRouteMetric:
return "UpdateRouteMetric"
case UpdateRouteMasquerade:
return "UpdateRouteMasquerade"
case UpdateRouteEnabled:
return "UpdateRouteEnabled"
case UpdateRouteNetworkIdentifier:
return "UpdateRouteNetworkIdentifier"
case UpdateRouteGroups:
return "UpdateRouteGroups"
default:
return "InvalidOperation"
}
}
// RouteUpdateOperation operation object with type and values to be applied
type RouteUpdateOperation struct {
Type RouteUpdateOperationType
Values []string
}
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID)
@@ -241,109 +189,6 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
return nil
}
// UpdateRoute updates existing route with set of operations
func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
routeToUpdate, ok := account.Routes[routeID]
if !ok {
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
}
newRoute := routeToUpdate.Copy()
for _, operation := range operations {
if len(operation.Values) != 1 {
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
}
switch operation.Type {
case UpdateRouteDescription:
newRoute.Description = operation.Values[0]
case UpdateRouteNetworkIdentifier:
if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
newRoute.NetID = operation.Values[0]
case UpdateRouteNetwork:
prefixType, prefix, err := route.ParseNetwork(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0])
}
err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix)
if err != nil {
return nil, err
}
newRoute.Network = prefix
newRoute.NetworkType = prefixType
case UpdateRoutePeer:
if operation.Values[0] != "" {
peer := account.GetPeer(operation.Values[0])
if peer == nil {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", operation.Values[0])
}
}
err = am.checkPrefixPeerExists(accountID, operation.Values[0], routeToUpdate.Network)
if err != nil {
return nil, err
}
newRoute.Peer = operation.Values[0]
case UpdateRouteMetric:
metric, err := strconv.Atoi(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
}
if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
operation.Values[0],
route.MinMetric,
route.MaxMetric,
)
}
newRoute.Metric = metric
case UpdateRouteMasquerade:
masquerade, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
}
newRoute.Masquerade = masquerade
case UpdateRouteEnabled:
enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
}
newRoute.Enabled = enabled
case UpdateRouteGroups:
err = validateGroups(operation.Values, account.Groups)
if err != nil {
return nil, err
}
newRoute.Groups = operation.Values
}
}
account.Routes[routeID] = newRoute
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
return nil, err
}
err = am.updateAccountPeers(account)
if err != nil {
return nil, status.Errorf(status.Internal, "failed to update account peers")
}
return newRoute, nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID)

View File

@@ -524,265 +524,6 @@ func TestSaveRoute(t *testing.T) {
}
}
func TestUpdateRoute(t *testing.T) {
routeID := "testingRouteID"
existingRoute := &route.Route{
ID: routeID,
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superRoute",
NetworkType: route.IPv4Network,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
}
testCases := []struct {
name string
existingRoute *route.Route
operations []RouteUpdateOperation
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedRoute *route.Route
}{
{
name: "Happy Path Single OPS",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRoutePeer,
Values: []string{peer2ID},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: routeID,
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superRoute",
NetworkType: route.IPv4Network,
Peer: peer2ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
},
{
name: "Happy Path Multiple OPS",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteDescription,
Values: []string{"great"},
},
{
Type: UpdateRouteNetwork,
Values: []string{"192.168.0.0/24"},
},
{
Type: UpdateRoutePeer,
Values: []string{peer2ID},
},
{
Type: UpdateRouteMetric,
Values: []string{"3030"},
},
{
Type: UpdateRouteMasquerade,
Values: []string{"true"},
},
{
Type: UpdateRouteEnabled,
Values: []string{"false"},
},
{
Type: UpdateRouteNetworkIdentifier,
Values: []string{"megaRoute"},
},
{
Type: UpdateRouteGroups,
Values: []string{routeGroup2},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: routeID,
Network: netip.MustParsePrefix("192.168.0.0/24"),
NetID: "megaRoute",
NetworkType: route.IPv4Network,
Peer: peer2ID,
Description: "great",
Masquerade: true,
Metric: 3030,
Enabled: false,
Groups: []string{routeGroup2},
},
},
{
name: "Empty Values Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRoutePeer,
},
},
errFunc: require.Error,
},
{
name: "Multiple Values Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRoutePeer,
Values: []string{peer2ID, peer1ID},
},
},
errFunc: require.Error,
},
{
name: "Bad Prefix Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteNetwork,
Values: []string{"192.168.0.0/34"},
},
},
errFunc: require.Error,
},
{
name: "Bad Peer Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRoutePeer,
Values: []string{"non existing Peer"},
},
},
errFunc: require.Error,
},
{
name: "Empty Peer",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRoutePeer,
Values: []string{""},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: routeID,
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superRoute",
NetworkType: route.IPv4Network,
Peer: "",
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
},
{
name: "Large Network ID Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteNetworkIdentifier,
Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"},
},
},
errFunc: require.Error,
},
{
name: "Empty Network ID Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteNetworkIdentifier,
Values: []string{""},
},
},
errFunc: require.Error,
},
{
name: "Invalid Metric Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteMetric,
Values: []string{"999999"},
},
},
errFunc: require.Error,
},
{
name: "Invalid Boolean Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteMasquerade,
Values: []string{"yes"},
},
},
errFunc: require.Error,
},
{
name: "Invalid Group Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteGroups,
Values: []string{routeInvalidGroup1},
},
},
errFunc: require.Error,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
am, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestRouteAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute
err = am.Store.SaveAccount(account)
if err != nil {
t.Error("account should be saved")
}
updatedRoute, err := am.UpdateRoute(account.Id, testCase.existingRoute.ID, testCase.operations)
testCase.errFunc(t, err)
if !testCase.shouldCreate {
return
}
testCase.expectedRoute.ID = updatedRoute.ID
if !testCase.expectedRoute.IsEqual(updatedRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedRoute)
}
})
}
}
func TestDeleteRoute(t *testing.T) {
testingRoute := &route.Route{
ID: "testingRoute",
@@ -940,7 +681,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore)
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false)
}
func createRouterStore(t *testing.T) (Store, error) {

View File

@@ -2,6 +2,7 @@ package telemetry
import (
"context"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
@@ -13,6 +14,7 @@ type IDPMetrics struct {
getUserByEmailCounter syncint64.Counter
getAllAccountsCounter syncint64.Counter
createUserCounter syncint64.Counter
deleteUserCounter syncint64.Counter
getAccountCounter syncint64.Counter
getUserByIDCounter syncint64.Counter
authenticateRequestCounter syncint64.Counter
@@ -39,6 +41,10 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error)
if err != nil {
return nil, err
}
deleteUserCounter, err := meter.SyncInt64().Counter("management.idp.delete.user.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
@@ -65,6 +71,7 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error)
getUserByEmailCounter: getUserByEmailCounter,
getAllAccountsCounter: getAllAccountsCounter,
createUserCounter: createUserCounter,
deleteUserCounter: deleteUserCounter,
getAccountCounter: getAccountCounter,
getUserByIDCounter: getUserByIDCounter,
authenticateRequestCounter: authenticateRequestCounter,
@@ -88,6 +95,11 @@ func (idpMetrics *IDPMetrics) CountCreateUser() {
idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1)
}
// CountDeleteUser ...
func (idpMetrics *IDPMetrics) CountDeleteUser() {
idpMetrics.deleteUserCounter.Add(idpMetrics.ctx, 1)
}
// CountGetAllAccounts ...
func (idpMetrics *IDPMetrics) CountGetAllAccounts() {
idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1)

View File

@@ -327,15 +327,43 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin {
return status.Errorf(status.PermissionDenied, "only admins can delete service users")
return status.Errorf(status.PermissionDenied, "only admins can delete users")
}
if !targetUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "regular users can not be deleted")
peers, err := account.FindUserPeers(targetUserID)
if err != nil {
return status.Errorf(status.Internal, "failed to find user peers")
}
meta := map[string]any{"name": targetUser.ServiceUserName}
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
if err := am.expireAndUpdatePeers(account, peers); err != nil {
log.Errorf("failed update deleted peers expiration: %s", err)
return err
}
targetUserEmail, err := am.getEmailOfTargetUser(account.Id, initiatorUserID, targetUserID)
if err != nil {
log.Errorf("failed to resolve email address: %s", err)
return err
}
var meta map[string]any
var eventAction activity.Activity
if targetUser.IsServiceUser {
meta = map[string]any{"name": targetUser.ServiceUserName}
eventAction = activity.ServiceUserDeleted
} else {
meta = map[string]any{"email": targetUserEmail}
eventAction = activity.UserDeleted
}
am.storeEvent(initiatorUserID, targetUserID, accountID, eventAction, meta)
if !isNil(am.idpManager) {
err := am.deleteUserFromIDP(targetUserID, accountID)
if err != nil {
return err
}
}
delete(account.Users, targetUserID)
@@ -609,23 +637,10 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
if err != nil {
return nil, err
}
var peerIDs []string
for _, peer := range blockedPeers {
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return nil, err
}
}
am.peersUpdateManager.CloseChannels(peerIDs)
err = am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
return nil, err
if err := am.expireAndUpdatePeers(account, blockedPeers); err != nil {
log.Errorf("failed update expired peers: %s", err)
return nil, err
}
}
@@ -814,6 +829,67 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
return userInfos, nil
}
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*Peer) error {
var peerIDs []string
for _, peer := range peers {
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
return err
}
am.storeEvent(
peer.UserID, peer.ID, account.Id,
activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()),
)
}
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(peerIDs)
if err := am.updateAccountPeers(account); err != nil {
return err
}
}
return nil
}
func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID string) error {
if am.userDeleteFromIDPEnabled {
log.Debugf("user %s deleted from IdP", targetUserID)
err := am.idpManager.DeleteUser(targetUserID)
if err != nil {
return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err)
}
} else {
err := am.idpManager.UpdateUserAppMetadata(targetUserID, idp.AppMetadata{})
if err != nil {
return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err)
}
_, err = am.refreshCache(accountID)
if err != nil {
log.Errorf("refresh account (%q) cache: %v", accountID, err)
}
}
return nil
}
func (am *DefaultAccountManager) getEmailOfTargetUser(accountId string, initiatorId, targetId string) (string, error) {
userInfos, err := am.GetUsersFromAccount(accountId, initiatorId)
if err != nil {
return "", err
}
for _, ui := range userInfos {
if ui.ID == targetId {
return ui.Email, nil
}
}
return "", fmt.Errorf("email not found for user: %s", targetId)
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData {
if user.ID == userID {

View File

@@ -439,8 +439,9 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
}
err = am.DeleteUser(mockAccountID, mockUserID, mockUserID)
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
if err != nil {
t.Errorf("unexpected error: %s", err)
}
}
func TestDefaultAccountManager_GetUser(t *testing.T) {