Compare commits

...

13 Commits

Author SHA1 Message Date
Zoltan Papp
da7b6b11ad Fix/user deletion (#1157)
Extend the deleted user info with the username
- Because initially, we did not store the user name in the activity db 
Sometimes, we can not provide the user name in the API response.

Fix service user deletion
  - In case of service user deletion, do not invoke the IdP delete function
  - Prevent self deletion
2023-09-23 10:47:49 +02:00
Maycon Santos
e260270825 Add direct write file to avoid moving docker mounted files (#1155)
Add a direct write to handle management.json write operation. 

Remove empty configuration types to avoid unnecessary fields in the generated management.json file.
2023-09-22 10:25:04 +02:00
Givi Khojanashvili
d4b6d7646c Handle user delete (#1113)
Implement user deletion across all IDP-ss. Expires all user peers
when the user is deleted. Users are permanently removed from a local
store, but in IDP, we remove Netbird attributes for the user
untilUserDeleteFromIDPEnabled setting is not enabled.

To test, an admin user should remove any additional users.

Until the UI incorporates this feature, use a curl DELETE request
targeting the /users/<USER_ID> management endpoint. Note that this
request only removes user attributes and doesn't trigger a delete
from the IDP.

To enable user removal from the IdP, set UserDeleteFromIDPEnabled
to true in account settings. Until we have a UI for this, make this
change directly in the store file.

Store the deleted email addresses in encrypted in activity store.
2023-09-19 18:08:40 +02:00
Bethuel Mmbaga
8febab4076 Improve Client Authentication (#1135)
* shutdown the pkce server on user cancellation

* Refactor openURL to exclusively manage authentication flow instructions and browser launching

* Refactor authentication flow initialization based on client OS

The NewOAuthFlow method now first checks the operating system and if it is a non-desktop Linux, it opts for Device Code Flow. PKCEFlow is tried first and if it fails, then it falls back on Device Code Flow. If both unsuccessful, the authentication process halts and error messages have been updated to provide more helpful feedback for troubleshooting authentication errors

* Replace log-based Linux desktop check with process check

To verify if a Linux OS is running a desktop environment in the Authentication utility, the log-based method that checks the XDG_CURRENT_DESKTOP env has been replaced with a method that checks directly if either X or Wayland display server processes are running. This method is more reliable as it directly checks for the display server process rather than relying on an environment variable that may not be set in all desktop environments.

* Refactor PKCE Authorization Flow to improve server handling

* refactor check for linux running desktop environment

* Improve server shutdown handling and encapsulate handlers with new server multiplexer

The changes enhance the way the server shuts down by specifying a context with timeout of 5 seconds, adding a safeguard to ensure the server halts even on potential hanging requests. Also, the server's root handler is now encapsulated within a new ServeMux instance, to support multiple registrations of a path
2023-09-19 19:06:18 +03:00
Zoltan Papp
34e2c6b943 Fix sso check (#1152)
Fix SSO check

- change the order of the PKCE and device auth flow check, prefer PKCE
- fix error handling in PKCE check
2023-09-18 16:04:53 +02:00
Yury Gargay
0be8c72601 Remove unused methods from AccountManager interface (#1149)
This PR removes the following unused methods from the AccountManager interface:
* `UpdateGroup`
* `UpdateNameServerGroup`
* `UpdateRoute`
2023-09-18 12:25:12 +02:00
Maycon Santos
c34e53477f Add signal port tests to CI workflow (#1148) 2023-09-14 17:01:14 +02:00
Fabio Fantoni
8d18190c94 fix NETBIRD_SIGNAL_PORT not working with custom port (#1143) (#1145)
Use NETBIRD_SIGNAL_PORT variable instead of the static port for signal
container in the docker-compose template to make setting of custom
signal port working

Signed-off-by: Fabio Fantoni <fabio.fantoni@m2r.biz>
2023-09-14 15:58:28 +02:00
Zoltan Papp
06bec61be9 Add Android test build (#1144)
Extend the CI with gomobile build.
With this step we can validate that the code can run on Android
2023-09-13 17:58:12 +02:00
Zoltan Papp
2135533f1d Fix Android build (#1142)
The source code files related to the Android firewall had incorrect build tags.
2023-09-13 17:36:24 +02:00
Bethuel Mmbaga
bb791d59f3 update check for linux running desktop (#1137) 2023-09-08 20:08:02 +02:00
Maycon Santos
30f1c54ed1 Fix: docker test for infrastructure files (#1136)
* Fix: docker test for infrastructure files

* Fix: docker test for infrastructure files
2023-09-08 19:28:34 +02:00
Maycon Santos
5c8541ef42 Set not found ebpf log to Info (#1134)
added an additional log event
2023-09-08 18:24:19 +02:00
56 changed files with 1115 additions and 1328 deletions

View File

@@ -0,0 +1,41 @@
name: Android build validation
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v4
with:
go-version: "1.20.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v2
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v3
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
- name: Setup NDK
run: /usr/local/lib/android/sdk/tools/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build android nebtird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@@ -80,6 +80,7 @@ jobs:
CI_NETBIRD_MGMT_IDP: "none" CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
run: | run: |
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
@@ -91,6 +92,7 @@ jobs:
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073" grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$' grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$' grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
@@ -120,7 +122,7 @@ jobs:
- name: test running containers - name: test running containers
run: | run: |
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running) count=$(docker compose ps --format json | jq '. | select(.Name | contains("infrastructure_files")) | .State' | grep -c running)
test $count -eq 4 test $count -eq 4
working-directory: infrastructure_files working-directory: infrastructure_files

View File

@@ -84,10 +84,14 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false supportsSSO = false
err = nil err = nil
} }

View File

@@ -3,8 +3,6 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"runtime"
"strings" "strings"
"time" "time"
@@ -195,60 +193,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
} }
browserAuthMsg := "Please do the SSO login in your browser. \n" + cmd.Println("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" + "If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg verificationURIComplete + " " + codeMsg)
cmd.Println("")
setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" + if err := open.Run(verificationURIComplete); err != nil {
"https://docs.netbird.io/how-to/register-machines-using-setup-keys" cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
authenticateUsingBrowser := func() {
cmd.Println(browserAuthMsg)
cmd.Println("")
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println(setupKeyAuthMsg)
}
}
switch runtime.GOOS {
case "windows", "darwin":
authenticateUsingBrowser()
case "linux":
if isLinuxRunningDesktop() {
authenticateUsingBrowser()
} else {
// If current flow is PKCE, it implies the server is anticipating the redirect to localhost.
// Devices lacking browser support are incompatible with this flow.Therefore,
// these devices will need to resort to setup keys instead.
if isPKCEFlow(verificationURIComplete) {
cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} else {
cmd.Println(browserAuthMsg)
}
}
} }
} }
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment.
func isLinuxRunningDesktop() bool {
for _, env := range os.Environ() {
values := strings.Split(env, "=")
if len(values) == 2 {
key, value := values[0], values[1]
if key == "XDG_CURRENT_DESKTOP" && value != "" {
return true
}
}
}
return false
}
// isPKCEFlow determines if the PKCE flow is active or not,
// by checking the existence of redirect_uri inside the verification URL.
func isPKCEFlow(verificationURL string) bool {
if verificationURL == "" {
return false
}
return strings.Contains(verificationURL, "redirect_uri")
}

View File

@@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
return nil, nil return nil, nil
} }
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore) eventStore, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -1,4 +1,4 @@
//go:build !linux //go:build !linux || android
package acl package acl

View File

@@ -1,3 +1,5 @@
//go:build !android
package acl package acl
import ( import (

View File

@@ -4,8 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
@@ -57,25 +57,43 @@ func (t TokenInfo) GetTokenToUse() string {
return t.AccessToken return t.AccessToken
} }
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration. // NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
//
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
log.Debug("loading pkce authorization flow info") if runtime.GOOS == "linux" && !isLinuxRunningDesktop() {
return authenticateWithDeviceCodeFlow(ctx, config)
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err == nil {
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
} }
log.Debugf("loading pkce authorization flow info failed with error: %v", err) pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
log.Debugf("falling back to device authorization flow info") if err != nil {
// fallback to device code flow
return authenticateWithDeviceCodeFlow(ctx, config)
}
return pkceFlow, nil
}
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil { if err != nil {
s, ok := gstatus.FromError(err) s, ok := gstatus.FromError(err)
if ok && s.Code() == codes.NotFound { if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " + return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " + "Please proceed with setting up this device using setup keys " +
"https://github.com/netbirdio/netbird/tree/main/management for details") "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} else if ok && s.Code() == codes.Unimplemented { } else if ok && s.Code() == codes.Unimplemented {
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+ return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your server or use Setup Keys to login", config.ManagementURL) "please update your server or use Setup Keys to login", config.ManagementURL)

View File

@@ -12,7 +12,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -80,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
} }
// RequestAuthInfo requests a authorization code login flow information. // RequestAuthInfo requests a authorization code login flow information.
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) { func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
state, err := randomBytesInHex(24) state, err := randomBytesInHex(24)
if err != nil { if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err) return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
@@ -114,64 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
tokenChan := make(chan *oauth2.Token, 1) tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1) errChan := make(chan error, 1)
go p.startServer(tokenChan, errChan) parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
}
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
log.Errorf("failed to close the server: %v", err)
}
}()
go p.startServer(server, tokenChan, errChan)
select { select {
case <-ctx.Done(): case <-ctx.Done():
return TokenInfo{}, ctx.Err() return TokenInfo{}, ctx.Err()
case token := <-tokenChan: case token := <-tokenChan:
return p.handleOAuthToken(token) return p.parseOAuthToken(token)
case err := <-errChan: case err := <-errChan:
return TokenInfo{}, err return TokenInfo{}, err
} }
} }
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) { func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
var wg sync.WaitGroup mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) token, err := p.handleRequest(req)
if err != nil {
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
return
}
server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
go func() {
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errChan <- err
}
}()
wg.Add(1)
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
defer wg.Done()
tokenValidatorFunc := func() (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
token, err := tokenValidatorFunc()
if err != nil { if err != nil {
renderPKCEFlowTmpl(w, err) renderPKCEFlowTmpl(w, err)
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err) errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
@@ -182,13 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC
tokenChan <- token tokenChan <- token
}) })
wg.Wait() server.Handler = mux
if err := server.Shutdown(context.Background()); err != nil { if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Errorf("error while shutting down pkce flow server: %v", err) errChan <- err
} }
} }
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) { func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on the state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
tokenInfo := TokenInfo{ tokenInfo := TokenInfo{
AccessToken: token.AccessToken, AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,

View File

@@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"os"
"reflect" "reflect"
"strings" "strings"
) )
@@ -60,3 +61,8 @@ func isValidAccessToken(token string, audience string) error {
return fmt.Errorf("invalid JWT token audience field") return fmt.Errorf("invalid JWT token audience field")
} }
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool {
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}

View File

@@ -1,3 +1,3 @@
//go:build !linux //go:build !linux || android
package checkfw package checkfw

View File

@@ -1049,7 +1049,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
return nil, "", err return nil, "", err
} }
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore) eventStore, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -135,6 +135,7 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
} }
p.removeTurnConn(endpointPort) p.removeTurnConn(endpointPort)
log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err)
return return
} }
err = p.sendPkg(buf[:n], endpointPort) err = p.sendPkg(buf[:n], endpointPort)
@@ -158,7 +159,7 @@ func (p *WGEBPFProxy) proxyToRemote() {
conn, ok := p.turnConnStore[uint16(addr.Port)] conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock() p.turnConnMutex.Unlock()
if !ok { if !ok {
log.Errorf("turn conn not found by port: %d", addr.Port) log.Infof("turn conn not found by port: %d", addr.Port)
continue continue
} }

View File

@@ -36,7 +36,7 @@ services:
volumes: volumes:
- $SIGNAL_VOLUMENAME:/var/lib/netbird - $SIGNAL_VOLUMENAME:/var/lib/netbird
ports: ports:
- 10000:80 - $NETBIRD_SIGNAL_PORT:80
# # port and command for Let's Encrypt validation # # port and command for Let's Encrypt validation
# - 443:443 # - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]

View File

@@ -21,4 +21,5 @@ NETBIRD_AUTH_USER_ID_CLAIM="email"
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email" NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email"
NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP
NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID
NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
NETBIRD_SIGNAL_PORT=12345

View File

@@ -61,7 +61,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager() peersUpdateManager := mgmt.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore) eventStore, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/activity/sqlite" "github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http" httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
@@ -142,12 +143,22 @@ var (
if disableSingleAccMode { if disableSingleAccMode {
mgmtSingleAccModeDomain = "" mgmtSingleAccModeDomain = ""
} }
eventStore, err := sqlite.NewSQLiteStore(config.Datadir) eventStore, key, err := initEventStore(config.Datadir, config.DataStoreEncryptionKey)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to initialize database: %s", err)
} }
if key != "" {
log.Infof("update config with activity store key")
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(mgmtConfig, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
}
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore) dnsDomain, eventStore, userDeleteFromIDPEnabled)
if err != nil { if err != nil {
return fmt.Errorf("failed to build default manager: %v", err) return fmt.Errorf("failed to build default manager: %v", err)
} }
@@ -287,6 +298,20 @@ var (
} }
) )
func initEventStore(dataDir string, key string) (activity.Store, string, error) {
var err error
if key == "" {
log.Debugf("generate new activity store encryption key")
key, err = sqlite.GenerateKey()
if err != nil {
return nil, "", err
}
}
store, err := sqlite.NewSQLiteStore(dataDir, key)
return store, key, err
}
func notifyStop(msg string) { func notifyStop(msg string) {
select { select {
case stopCh <- 1: case stopCh <- 1:
@@ -440,6 +465,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
return loadedConfig, err return loadedConfig, err
} }
func updateMgmtConfig(path string, config *server.Config) error {
return util.DirectWriteJson(path, config)
}
// OIDCConfigResponse used for parsing OIDC config response // OIDCConfigResponse used for parsing OIDC config response
type OIDCConfigResponse struct { type OIDCConfigResponse struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`

View File

@@ -24,6 +24,7 @@ var (
disableMetrics bool disableMetrics bool
disableSingleAccMode bool disableSingleAccMode bool
idpSignKeyRefreshEnabled bool idpSignKeyRefreshEnabled bool
userDeleteFromIDPEnabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird-mgmt", Use: "netbird-mgmt",
@@ -56,6 +57,7 @@ func init() {
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.") mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account")
rootCmd.MarkFlagRequired("config") //nolint rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@@ -80,7 +80,6 @@ type AccountManager interface {
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error) GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountID, userID string, group *Group) error SaveGroup(accountID, userID string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
DeleteGroup(accountId, userId, groupID string) error DeleteGroup(accountId, userId, groupID string) error
ListGroups(accountId string) ([]*Group, error) ListGroups(accountId string) ([]*Group, error)
GroupAddPeer(accountId, groupID, peerID string) error GroupAddPeer(accountId, groupID, peerID string) error
@@ -93,13 +92,11 @@ type AccountManager interface {
GetRoute(accountID, routeID, userID string) (*route.Route, error) GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
SaveRoute(accountID, userID string, route *route.Route) error SaveRoute(accountID, userID string, route *route.Route) error
UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID, userID string) error DeleteRoute(accountID, routeID, userID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error) ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroup(accountID, nsGroupID, userID string) error DeleteNameServerGroup(accountID, nsGroupID, userID string) error
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain() string GetDNSDomain() string
@@ -133,6 +130,9 @@ type DefaultAccountManager struct {
// dnsDomain is used for peer resolution. This is appended to the peer's name // dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string dnsDomain string
peerLoginExpiry Scheduler peerLoginExpiry Scheduler
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
userDeleteFromIDPEnabled bool
} }
// Settings represents Account settings structure that can be modified via API and Dashboard // Settings represents Account settings structure that can be modified via API and Dashboard
@@ -738,18 +738,19 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, userDeleteFromIDPEnabled bool,
) (*DefaultAccountManager, error) { ) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{ am := &DefaultAccountManager{
Store: store, Store: store,
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
idpManager: idpManager, idpManager: idpManager,
ctx: context.Background(), ctx: context.Background(),
cacheMux: sync.Mutex{}, cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{}, cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain, dnsDomain: dnsDomain,
eventStore: eventStore, eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(), peerLoginExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
} }
allAccounts := store.GetAllAccounts() allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1 // enable single account mode only if configured by user and number of existing accounts is not grater than 1
@@ -874,33 +875,19 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func()
return account.GetNextPeerExpiration() return account.GetNextPeerExpiration()
} }
expiredPeers := account.GetExpiredPeers()
var peerIDs []string var peerIDs []string
for _, peer := range account.GetExpiredPeers() { for _, peer := range expiredPeers {
if peer.Status.LoginExpired {
continue
}
peerIDs = append(peerIDs, peer.ID) peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return account.GetNextPeerExpiration()
}
am.storeEvent(peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()))
} }
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
if len(peerIDs) != 0 { if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil {
// this will trigger peer disconnect from the management service log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
am.peersUpdateManager.CloseChannels(peerIDs) return account.GetNextPeerExpiration()
err = am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
return account.GetNextPeerExpiration()
}
} }
return account.GetNextPeerExpiration() return account.GetNextPeerExpiration()
} }
} }
@@ -1605,19 +1592,3 @@ func newAccountWithId(accountID, userID, domain string) *Account {
} }
return acc return acc
} }
func removeFromList(inputList []string, toRemove []string) []string {
toRemoveMap := make(map[string]struct{})
for _, item := range toRemove {
toRemoveMap[item] = struct{}{}
}
var resultList []string
for _, item := range inputList {
_, ok := toRemoveMap[item]
if !ok {
resultList = append(resultList, item)
}
}
return resultList
}

View File

@@ -2063,7 +2063,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore) return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore, false)
} }
func createStore(t *testing.T) (Store, error) { func createStore(t *testing.T) (Store, error) {

View File

@@ -104,6 +104,8 @@ const (
UserBlocked UserBlocked
// UserUnblocked indicates that a user unblocked another user // UserUnblocked indicates that a user unblocked another user
UserUnblocked UserUnblocked
// UserDeleted indicates that a user deleted another user
UserDeleted
// GroupDeleted indicates that a user deleted group // GroupDeleted indicates that a user deleted group
GroupDeleted GroupDeleted
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login // UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
@@ -162,6 +164,7 @@ var activityMap = map[Activity]Code{
ServiceUserDeleted: {"Service user deleted", "service.user.delete"}, ServiceUserDeleted: {"Service user deleted", "service.user.delete"},
UserBlocked: {"User blocked", "user.block"}, UserBlocked: {"User blocked", "user.block"},
UserUnblocked: {"User unblocked", "user.unblock"}, UserUnblocked: {"User unblocked", "user.unblock"},
UserDeleted: {"User deleted", "user.delete"},
GroupDeleted: {"Group deleted", "group.delete"}, GroupDeleted: {"Group deleted", "group.delete"},
UserLoggedInPeer: {"User logged in peer", "user.peer.login"}, UserLoggedInPeer: {"User logged in peer", "user.peer.login"},
PeerLoginExpired: {"Peer login expired", "peer.login.expire"}, PeerLoginExpired: {"Peer login expired", "peer.login.expire"},

View File

@@ -18,10 +18,15 @@ type Event struct {
ID uint64 ID uint64
// InitiatorID is the ID of an object that initiated the event (e.g., a user) // InitiatorID is the ID of an object that initiated the event (e.g., a user)
InitiatorID string InitiatorID string
// InitiatorName is the name of an object that initiated the event.
InitiatorName string
// InitiatorEmail is the email address of an object that initiated the event.
InitiatorEmail string
// TargetID is the ID of an object that was effected by the event (e.g., a peer) // TargetID is the ID of an object that was effected by the event (e.g., a peer)
TargetID string TargetID string
// AccountID is the ID of an account where the event happened // AccountID is the ID of an account where the event happened
AccountID string AccountID string
// Meta of the event, e.g. deleted peer information like name, IP, etc // Meta of the event, e.g. deleted peer information like name, IP, etc
Meta map[string]any Meta map[string]any
} }
@@ -35,12 +40,14 @@ func (e *Event) Copy() *Event {
} }
return &Event{ return &Event{
Timestamp: e.Timestamp, Timestamp: e.Timestamp,
Activity: e.Activity, Activity: e.Activity,
ID: e.ID, ID: e.ID,
InitiatorID: e.InitiatorID, InitiatorID: e.InitiatorID,
TargetID: e.TargetID, InitiatorName: e.InitiatorName,
AccountID: e.AccountID, InitiatorEmail: e.InitiatorEmail,
Meta: meta, TargetID: e.TargetID,
AccountID: e.AccountID,
Meta: meta,
} }
} }

View File

@@ -0,0 +1,81 @@
package sqlite
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type FieldEncrypt struct {
block cipher.Block
}
func GenerateKey() (string, error) {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
return "", err
}
readableKey := base64.StdEncoding.EncodeToString(key)
return readableKey, nil
}
func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
binKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(binKey)
if err != nil {
return nil, err
}
ec := &FieldEncrypt{
block: block,
}
return ec, nil
}
func (ec *FieldEncrypt) Encrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, plainText)
return base64.StdEncoding.EncodeToString(cipherText)
}
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
cbc := cipher.NewCBCDecrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, cipherText)
payload, err := pkcs5UnPadding(cipherText)
if err != nil {
return "", err
}
return string(payload), nil
}
func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
paddingLen := int(src[srcLen-1])
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
return nil, fmt.Errorf("padding size error")
}
return src[:srcLen-paddingLen], nil
}

View File

@@ -0,0 +1,63 @@
package sqlite
import (
"testing"
)
func TestGenerateKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.Decrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
newKey, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err = NewFieldEncrypt(newKey)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
res, _ := ee.Decrypt(encrypted)
if res == testData {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}

View File

@@ -3,14 +3,14 @@ package sqlite
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/activity"
// sqlite driver
"path/filepath" "path/filepath"
"time" "time"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
) )
const ( const (
@@ -25,69 +25,122 @@ const (
"meta TEXT," + "meta TEXT," +
" target_id TEXT);" " target_id TEXT);"
selectDescQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`
" FROM events WHERE account_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?;"
selectAscQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
" FROM events WHERE account_id = ? ORDER BY timestamp ASC LIMIT ? OFFSET ?;" FROM events
LEFT JOIN deleted_users i ON events.initiator_id = i.id
LEFT JOIN deleted_users t ON events.target_id = t.id
WHERE account_id = ?
ORDER BY timestamp DESC LIMIT ? OFFSET ?;`
selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events
LEFT JOIN deleted_users i ON events.initiator_id = i.id
LEFT JOIN deleted_users t ON events.target_id = t.id
WHERE account_id = ?
ORDER BY timestamp ASC LIMIT ? OFFSET ?;`
insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " + insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " +
"VALUES(?, ?, ?, ?, ?, ?)" "VALUES(?, ?, ?, ?, ?, ?)"
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
) )
// Store is the implementation of the activity.Store interface backed by SQLite // Store is the implementation of the activity.Store interface backed by SQLite
type Store struct { type Store struct {
db *sql.DB db *sql.DB
fieldEncrypt *FieldEncrypt
insertStatement *sql.Stmt insertStatement *sql.Stmt
selectAscStatement *sql.Stmt selectAscStatement *sql.Stmt
selectDescStatement *sql.Stmt selectDescStatement *sql.Stmt
deleteUserStmt *sql.Stmt
} }
// NewSQLiteStore creates a new Store with an event table if not exists. // NewSQLiteStore creates a new Store with an event table if not exists.
func NewSQLiteStore(dataDir string) (*Store, error) { func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
dbFile := filepath.Join(dataDir, eventSinkDB) dbFile := filepath.Join(dataDir, eventSinkDB)
db, err := sql.Open("sqlite3", dbFile) db, err := sql.Open("sqlite3", dbFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
crypt, err := NewFieldEncrypt(encryptionKey)
if err != nil {
_ = db.Close()
return nil, err
}
_, err = db.Exec(createTableQuery) _, err = db.Exec(createTableQuery)
if err != nil { if err != nil {
_ = db.Close()
return nil, err
}
_, err = db.Exec(creatTableDeletedUsersQuery)
if err != nil {
_ = db.Close()
return nil, err
}
err = updateDeletedUsersTable(db)
if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
insertStmt, err := db.Prepare(insertQuery) insertStmt, err := db.Prepare(insertQuery)
if err != nil { if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
selectDescStmt, err := db.Prepare(selectDescQuery) selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil { if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
selectAscStmt, err := db.Prepare(selectAscQuery) selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil { if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
return &Store{ deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
s := &Store{
db: db, db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt, insertStatement: insertStmt,
selectDescStatement: selectDescStmt, selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt, selectAscStatement: selectAscStmt,
}, nil deleteUserStmt: deleteUserStmt,
}
return s, nil
} }
func processResult(result *sql.Rows) ([]*activity.Event, error) { func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
events := make([]*activity.Event, 0) events := make([]*activity.Event, 0)
for result.Next() { for result.Next() {
var id int64 var id int64
var operation activity.Activity var operation activity.Activity
var timestamp time.Time var timestamp time.Time
var initiator string var initiator string
var initiatorName *string
var initiatorEmail *string
var target string var target string
var targetUserName *string
var targetEmail *string
var account string var account string
var jsonMeta string var jsonMeta string
err := result.Scan(&id, &operation, &timestamp, &initiator, &target, &account, &jsonMeta) err := result.Scan(&id, &operation, &timestamp, &initiator, &initiatorName, &initiatorEmail, &target, &targetUserName, &targetEmail, &account, &jsonMeta)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -100,7 +153,27 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
} }
} }
events = append(events, &activity.Event{ if targetUserName != nil {
name, err := store.fieldEncrypt.Decrypt(*targetUserName)
if err != nil {
log.Errorf("failed to decrypt username for target id: %s", target)
meta["username"] = ""
} else {
meta["username"] = name
}
}
if targetEmail != nil {
email, err := store.fieldEncrypt.Decrypt(*targetEmail)
if err != nil {
log.Errorf("failed to decrypt email address for target id: %s", target)
meta["email"] = ""
} else {
meta["email"] = email
}
}
event := &activity.Event{
Timestamp: timestamp, Timestamp: timestamp,
Activity: operation, Activity: operation,
ID: uint64(id), ID: uint64(id),
@@ -108,7 +181,27 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
TargetID: target, TargetID: target,
AccountID: account, AccountID: account,
Meta: meta, Meta: meta,
}) }
if initiatorName != nil {
name, err := store.fieldEncrypt.Decrypt(*initiatorName)
if err != nil {
log.Errorf("failed to decrypt username of initiator: %s", initiator)
} else {
event.InitiatorName = name
}
}
if initiatorEmail != nil {
email, err := store.fieldEncrypt.Decrypt(*initiatorEmail)
if err != nil {
log.Errorf("failed to decrypt email address of initiator: %s", initiator)
} else {
event.InitiatorEmail = email
}
}
events = append(events, event)
} }
return events, nil return events, nil
@@ -127,13 +220,18 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
} }
defer result.Close() //nolint defer result.Close() //nolint
return processResult(result) return store.processResult(result)
} }
// Save an event in the SQLite events table // Save an event in the SQLite events table end encrypt the "email" element in meta map
func (store *Store) Save(event *activity.Event) (*activity.Event, error) { func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
var jsonMeta string var jsonMeta string
if event.Meta != nil { meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event)
if err != nil {
return nil, err
}
if meta != nil {
metaBytes, err := json.Marshal(event.Meta) metaBytes, err := json.Marshal(event.Meta)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -156,6 +254,34 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
return eventCopy, nil return eventCopy, nil
} }
// saveDeletedUserEmailAndNameInEncrypted if the meta contains email and name then store it in encrypted way and delete
// this item from meta map
func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event) (map[string]any, error) {
email, ok := event.Meta["email"]
if !ok {
return event.Meta, nil
}
name, ok := event.Meta["name"]
if !ok {
return event.Meta, nil
}
encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
_, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName)
if err != nil {
return nil, err
}
if len(event.Meta) == 2 {
return nil, nil // nolint
}
delete(event.Meta, "email")
delete(event.Meta, "name")
return event.Meta, nil
}
// Close the Store // Close the Store
func (store *Store) Close() error { func (store *Store) Close() error {
if store.db != nil { if store.db != nil {
@@ -163,3 +289,44 @@ func (store *Store) Close() error {
} }
return nil return nil
} }
func updateDeletedUsersTable(db *sql.DB) error {
log.Debugf("check deleted_users table version")
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
if err != nil {
return err
}
defer rows.Close()
found := false
for rows.Next() {
var (
cid int
name string
dataType string
notNull int
dfltVal sql.NullString
pk int
)
err := rows.Scan(&cid, &name, &dataType, &notNull, &dfltVal, &pk)
if err != nil {
return err
}
if name == "name" {
found = true
break
}
}
err = rows.Err()
if err != nil {
return err
}
if found {
return nil
}
log.Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err
}

View File

@@ -12,7 +12,8 @@ import (
func TestNewSQLiteStore(t *testing.T) { func TestNewSQLiteStore(t *testing.T) {
dataDir := t.TempDir() dataDir := t.TempDir()
store, err := NewSQLiteStore(dataDir) key, _ := GenerateKey()
store, err := NewSQLiteStore(dataDir, key)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return

View File

@@ -35,7 +35,8 @@ type Config struct {
TURNConfig *TURNConfig TURNConfig *TURNConfig
Signal *Host Signal *Host
Datadir string Datadir string
DataStoreEncryptionKey string
HttpConfig *HttpServerConfig HttpConfig *HttpServerConfig

View File

@@ -191,7 +191,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore) return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore, false)
} }
func createDNSStore(t *testing.T) (Store, error) { func createDNSStore(t *testing.T) (Store, error) {

View File

@@ -33,26 +33,6 @@ type Group struct {
Peers []string Peers []string
} }
const (
// UpdateGroupName indicates a name update operation
UpdateGroupName GroupUpdateOperationType = iota
// InsertPeersToGroup indicates insert peers to group operation
InsertPeersToGroup
// RemovePeersFromGroup indicates a remove peers from group operation
RemovePeersFromGroup
// UpdateGroupPeers indicates a replacement of group peers list
UpdateGroupPeers
)
// GroupUpdateOperationType operation type
type GroupUpdateOperationType int
// GroupUpdateOperation operation object with type and values to be applied
type GroupUpdateOperation struct {
Type GroupUpdateOperationType
Values []string
}
// EventMeta returns activity event meta related to the group // EventMeta returns activity event meta related to the group
func (g *Group) EventMeta() map[string]any { func (g *Group) EventMeta() map[string]any {
return map[string]any{"name": g.Name} return map[string]any{"name": g.Name}
@@ -165,57 +145,6 @@ func difference(a, b []string) []string {
return diff return diff
} }
// UpdateGroup updates a group using a list of operations
func (am *DefaultAccountManager) UpdateGroup(accountID string,
groupID string, operations []GroupUpdateOperation,
) (*Group, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
groupToUpdate, ok := account.Groups[groupID]
if !ok {
return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID)
}
group := groupToUpdate.Copy()
for _, operation := range operations {
switch operation.Type {
case UpdateGroupName:
group.Name = operation.Values[0]
case UpdateGroupPeers:
group.Peers = operation.Values
case InsertPeersToGroup:
sourceList := group.Peers
resultList := removeFromList(sourceList, operation.Values)
group.Peers = append(resultList, operation.Values...)
case RemovePeersFromGroup:
sourceList := group.Peers
resultList := removeFromList(sourceList, operation.Values)
group.Peers = resultList
}
}
account.Groups[groupID] = group
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
return nil, err
}
err = am.updateAccountPeers(account)
if err != nil {
return nil, err
}
return group, nil
}
// DeleteGroup object of the peers // DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountLock(accountId) unlock := am.Store.AcquireAccountLock(accountId)

0
management/server/http/api/generate.sh Normal file → Executable file
View File

View File

@@ -922,6 +922,14 @@ components:
description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event. description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
type: string type: string
example: google-oauth2|123456789012345678901 example: google-oauth2|123456789012345678901
initiator_name:
description: The name of the initiator of the event.
type: string
example: John Doe
initiator_email:
description: The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event.
type: string
example: demo@netbird.io
target_id: target_id:
description: The ID of the target of the event. E.g., an ID of the peer that a user removed. description: The ID of the target of the event. E.g., an ID of the peer that a user removed.
type: string type: string
@@ -938,6 +946,8 @@ components:
- activity - activity
- activity_code - activity_code
- initiator_id - initiator_id
- initiator_name
- initiator_email
- target_id - target_id
- meta - meta
responses: responses:

View File

@@ -164,9 +164,15 @@ type Event struct {
// Id Event unique identifier // Id Event unique identifier
Id string `json:"id"` Id string `json:"id"`
// InitiatorEmail The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event.
InitiatorEmail string `json:"initiator_email"`
// InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event. // InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
InitiatorId string `json:"initiator_id"` InitiatorId string `json:"initiator_id"`
// InitiatorName The name of the initiator of the event.
InitiatorName string `json:"initiator_name"`
// Meta The metadata of the event // Meta The metadata of the event
Meta map[string]string `json:"meta"` Meta map[string]string `json:"meta"`

View File

@@ -45,14 +45,66 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
util.WriteError(err, w) util.WriteError(err, w)
return return
} }
events := make([]*api.Event, 0) events := make([]*api.Event, len(accountEvents))
for _, e := range accountEvents { for i, e := range accountEvents {
events = append(events, toEventResponse(e)) events[i] = toEventResponse(e)
}
err = h.fillEventsWithUserInfo(events, account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
} }
util.WriteJSONObject(w, events) util.WriteJSONObject(w, events)
} }
func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error {
// build email, name maps based on users
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId)
if err != nil {
log.Errorf("failed to get users from account: %s", err)
return err
}
emails := make(map[string]string)
names := make(map[string]string)
for _, ui := range userInfos {
emails[ui.ID] = ui.Email
names[ui.ID] = ui.Name
}
var ok bool
for _, event := range events {
// fill initiator
if event.InitiatorEmail == "" {
event.InitiatorEmail, ok = emails[event.InitiatorId]
if !ok {
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
}
}
if event.InitiatorName == "" {
// here to allowed to be empty because in the first release we did not store the name
event.InitiatorName = names[event.InitiatorId]
}
// fill target meta
email, ok := emails[event.TargetId]
if !ok {
continue
}
event.Meta["email"] = email
username, ok := names[event.TargetId]
if !ok {
continue
}
event.Meta["username"] = username
}
return nil
}
func toEventResponse(event *activity.Event) *api.Event { func toEventResponse(event *activity.Event) *api.Event {
meta := make(map[string]string) meta := make(map[string]string)
if event.Meta != nil { if event.Meta != nil {
@@ -60,13 +112,16 @@ func toEventResponse(event *activity.Event) *api.Event {
meta[s] = fmt.Sprintf("%v", a) meta[s] = fmt.Sprintf("%v", a)
} }
} }
return &api.Event{ e := &api.Event{
Id: fmt.Sprint(event.ID), Id: fmt.Sprint(event.ID),
InitiatorId: event.InitiatorID, InitiatorId: event.InitiatorID,
Activity: event.Activity.Message(), InitiatorName: event.InitiatorName,
ActivityCode: api.EventActivityCode(event.Activity.StringCode()), InitiatorEmail: event.InitiatorEmail,
TargetId: event.TargetID, Activity: event.Activity.Message(),
Timestamp: event.Timestamp, ActivityCode: api.EventActivityCode(event.Activity.StringCode()),
Meta: meta, TargetId: event.TargetID,
Timestamp: event.Timestamp,
Meta: meta,
} }
return e
} }

View File

@@ -37,6 +37,9 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
}, },
}, user, nil }, user, nil
}, },
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil
},
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {

View File

@@ -53,22 +53,6 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
Issued: server.GroupIssuedAPI, Issued: server.GroupIssuedAPI,
}, nil }, nil
}, },
UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
var group server.Group
group.ID = groupID
for _, operation := range operations {
switch operation.Type {
case server.UpdateGroupName:
group.Name = operation.Values[0]
case server.UpdateGroupPeers, server.InsertPeersToGroup:
group.Peers = operation.Values
case server.RemovePeersFromGroup:
default:
return nil, fmt.Errorf("no operation")
}
}
return &group, nil
},
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) { GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
for _, peer := range TestPeers { for _, peer := range TestPeers {
if peer.IP.String() == peerIP { if peer.IP.String() == peerIP {

View File

@@ -88,31 +88,6 @@ func initNameserversTestData() *NameserversHandler {
} }
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
}, },
UpdateNameServerGroupFunc: func(accountID, nsGroupID, _ string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
nsGroupToUpdate := baseExistingNSGroup.Copy()
if nsGroupID != nsGroupToUpdate.ID {
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
}
for _, operation := range operations {
switch operation.Type {
case server.UpdateNameServerGroupName:
nsGroupToUpdate.Name = operation.Values[0]
case server.UpdateNameServerGroupDescription:
nsGroupToUpdate.Description = operation.Values[0]
case server.UpdateNameServerGroupNameServers:
var parsedNSList []nbdns.NameServer
for _, nsURL := range operation.Values {
parsed, err := nbdns.ParseNameServerURL(nsURL)
if err != nil {
return nil, err
}
parsedNSList = append(parsedNSList, parsed)
}
nsGroupToUpdate.NameServers = parsedNSList
}
}
return nsGroupToUpdate, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, testingAccount.Users["test_user"], nil return testingNSAccount, testingAccount.Users["test_user"], nil
}, },

View File

@@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip" "net/netip"
"strconv"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
@@ -108,38 +107,6 @@ func initRoutesTestData() *RoutesHandler {
IP: netip.MustParseAddr(existingPeerID).AsSlice(), IP: netip.MustParseAddr(existingPeerID).AsSlice(),
}, nil }, nil
}, },
UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
routeToUpdate := baseExistingRoute
if routeID != routeToUpdate.ID {
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
}
for _, operation := range operations {
switch operation.Type {
case server.UpdateRouteNetwork:
routeToUpdate.NetworkType, routeToUpdate.Network, _ = route.ParseNetwork(operation.Values[0])
case server.UpdateRouteDescription:
routeToUpdate.Description = operation.Values[0]
case server.UpdateRouteNetworkIdentifier:
routeToUpdate.NetID = operation.Values[0]
case server.UpdateRoutePeer:
routeToUpdate.Peer = operation.Values[0]
if routeToUpdate.Peer == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", routeToUpdate.Peer)
}
case server.UpdateRouteMetric:
routeToUpdate.Metric, _ = strconv.Atoi(operation.Values[0])
case server.UpdateRouteMasquerade:
routeToUpdate.Masquerade, _ = strconv.ParseBool(operation.Values[0])
case server.UpdateRouteEnabled:
routeToUpdate.Enabled, _ = strconv.ParseBool(operation.Values[0])
case server.UpdateRouteGroups:
routeToUpdate.Groups = operation.Values
default:
return nil, fmt.Errorf("no operation")
}
}
return routeToUpdate, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, testingAccount.Users["test_user"], nil return testingAccount, testingAccount.Users["test_user"], nil
}, },

View File

@@ -513,7 +513,9 @@ func buildUserExportRequest() (string, error) {
return string(str), nil return string(str), nil
} }
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { func (am *Auth0Manager) createRequest(
method string, endpoint string, body io.Reader,
) (*http.Request, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -521,17 +523,23 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*
reqURL := am.authIssuer + endpoint reqURL := am.authIssuer + endpoint
payload := strings.NewReader(payloadStr) req, err := http.NewRequest(method, reqURL, body)
req, err := http.NewRequest("POST", reqURL, payload)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
return req, nil
}
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr))
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/json") req.Header.Add("content-type", "application/json")
return req, nil return req, nil
} }
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
@@ -737,6 +745,38 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
return nil return nil
} }
// DeleteUser from Auth0
func (am *Auth0Manager) DeleteUser(userID string) error {
req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
if err != nil {
return err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.Debugf("execute delete request: %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer func() {
err = resp.Body.Close()
if err != nil {
log.Errorf("close delete request body: %v", err)
}
}()
if resp.StatusCode != 204 {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob. // checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
// If the status is "completed", then return the downloadLink // If the status is "completed", then return the downloadLink
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) { func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {

View File

@@ -12,9 +12,10 @@ import (
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3" "goauthentik.io/api/v3"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
// AuthentikManager authentik manager client instance. // AuthentikManager authentik manager client instance.
@@ -453,6 +454,38 @@ func (am *AuthentikManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Authentik
func (am *AuthentikManager) DeleteUser(userID string) error {
ctx, err := am.authenticationContext()
if err != nil {
return err
}
userPk, err := strconv.ParseInt(userID, 10, 32)
if err != nil {
return err
}
resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute()
if err != nil {
return err
}
defer resp.Body.Close() // nolint
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountDeleteUser()
}
if resp.StatusCode != http.StatusNoContent {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode)
}
return nil
}
func (am *AuthentikManager) authenticationContext() (context.Context, error) { func (am *AuthentikManager) authenticationContext() (context.Context, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate()
if err != nil { if err != nil {

View File

@@ -454,6 +454,43 @@ func (am *AzureManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Azure
func (am *AzureManager) DeleteUser(userID string) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("delete idp user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountDeleteUser()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) { func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
q := url.Values{} q := url.Values{}
q.Add("$select", extensionFields) q.Add("$select", extensionFields)

View File

@@ -254,6 +254,19 @@ func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from GoogleWorkspace.
func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
if err := gm.usersService.Delete(userID).Do(); err != nil {
return err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountDeleteUser()
}
return nil
}
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey. // getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. // It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
// If that fails, it falls back to using the default Google credentials path. // If that fails, it falls back to using the default Google credentials path.

View File

@@ -18,6 +18,7 @@ type Manager interface {
CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmail(email string) ([]*UserData, error) GetUserByEmail(email string) ([]*UserData, error)
InviteUserByID(userID string) error InviteUserByID(userID string) error
DeleteUser(userID string) error
} }
// ClientConfig defines common client configuration for all IdP manager // ClientConfig defines common client configuration for all IdP manager
@@ -37,10 +38,10 @@ type Config struct {
ManagerType string ManagerType string
ClientConfig *ClientConfig ClientConfig *ClientConfig
ExtraConfig ExtraConfig ExtraConfig ExtraConfig
Auth0ClientCredentials Auth0ClientConfig Auth0ClientCredentials *Auth0ClientConfig
AzureClientCredentials AzureClientConfig AzureClientCredentials *AzureClientConfig
KeycloakClientCredentials KeycloakClientConfig KeycloakClientCredentials *KeycloakClientConfig
ZitadelClientCredentials ZitadelClientConfig ZitadelClientCredentials *ZitadelClientConfig
} }
// ManagerCredentials interface that authenticates using the credential of each type of idp // ManagerCredentials interface that authenticates using the credential of each type of idp
@@ -96,7 +97,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
case "auth0": case "auth0":
auth0ClientConfig := config.Auth0ClientCredentials auth0ClientConfig := config.Auth0ClientCredentials
if config.ClientConfig != nil { if config.ClientConfig != nil {
auth0ClientConfig = Auth0ClientConfig{ auth0ClientConfig = &Auth0ClientConfig{
Audience: config.ExtraConfig["Audience"], Audience: config.ExtraConfig["Audience"],
AuthIssuer: config.ClientConfig.Issuer, AuthIssuer: config.ClientConfig.Issuer,
ClientID: config.ClientConfig.ClientID, ClientID: config.ClientConfig.ClientID,
@@ -105,11 +106,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
} }
} }
return NewAuth0Manager(auth0ClientConfig, appMetrics) return NewAuth0Manager(*auth0ClientConfig, appMetrics)
case "azure": case "azure":
azureClientConfig := config.AzureClientCredentials azureClientConfig := config.AzureClientCredentials
if config.ClientConfig != nil { if config.ClientConfig != nil {
azureClientConfig = AzureClientConfig{ azureClientConfig = &AzureClientConfig{
ClientID: config.ClientConfig.ClientID, ClientID: config.ClientConfig.ClientID,
ClientSecret: config.ClientConfig.ClientSecret, ClientSecret: config.ClientConfig.ClientSecret,
GrantType: config.ClientConfig.GrantType, GrantType: config.ClientConfig.GrantType,
@@ -119,11 +120,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
} }
} }
return NewAzureManager(azureClientConfig, appMetrics) return NewAzureManager(*azureClientConfig, appMetrics)
case "keycloak": case "keycloak":
keycloakClientConfig := config.KeycloakClientCredentials keycloakClientConfig := config.KeycloakClientCredentials
if config.ClientConfig != nil { if config.ClientConfig != nil {
keycloakClientConfig = KeycloakClientConfig{ keycloakClientConfig = &KeycloakClientConfig{
ClientID: config.ClientConfig.ClientID, ClientID: config.ClientConfig.ClientID,
ClientSecret: config.ClientConfig.ClientSecret, ClientSecret: config.ClientConfig.ClientSecret,
GrantType: config.ClientConfig.GrantType, GrantType: config.ClientConfig.GrantType,
@@ -132,11 +133,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
} }
} }
return NewKeycloakManager(keycloakClientConfig, appMetrics) return NewKeycloakManager(*keycloakClientConfig, appMetrics)
case "zitadel": case "zitadel":
zitadelClientConfig := config.ZitadelClientCredentials zitadelClientConfig := config.ZitadelClientCredentials
if config.ClientConfig != nil { if config.ClientConfig != nil {
zitadelClientConfig = ZitadelClientConfig{ zitadelClientConfig = &ZitadelClientConfig{
ClientID: config.ClientConfig.ClientID, ClientID: config.ClientConfig.ClientID,
ClientSecret: config.ClientConfig.ClientSecret, ClientSecret: config.ClientConfig.ClientSecret,
GrantType: config.ClientConfig.GrantType, GrantType: config.ClientConfig.GrantType,
@@ -145,7 +146,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
} }
} }
return NewZitadelManager(zitadelClientConfig, appMetrics) return NewZitadelManager(*zitadelClientConfig, appMetrics)
case "authentik": case "authentik":
authentikConfig := AuthentikClientConfig{ authentikConfig := AuthentikClientConfig{
Issuer: config.ClientConfig.Issuer, Issuer: config.ClientConfig.Issuer,

View File

@@ -467,6 +467,47 @@ func (km *KeycloakManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Keycloack
func (km *KeycloakManager) DeleteUser(userID string) error {
jwtToken, err := km.credentials.Authenticate()
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountDeleteUser()
}
resp, err := km.httpClient.Do(req)
if err != nil {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close() // nolint
// In the docs, they specified 200, but in the endpoints, they return 204
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) { func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
attrs := keycloakUserAttributes{} attrs := keycloakUserAttributes{}
attrs.Set(wtAccountID, appMetadata.WTAccountID) attrs.Set(wtAccountID, appMetadata.WTAccountID)

View File

@@ -319,6 +319,28 @@ func (om *OktaManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Okta
func (om *OktaManager) DeleteUser(userID string) error {
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
if err != nil {
fmt.Println(err.Error())
return err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountDeleteUser()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
// updateUserProfileSchema updates the Okta user schema to include custom fields, // updateUserProfileSchema updates the Okta user schema to include custom fields,
// wt_account_id and wt_pending_invite. // wt_account_id and wt_pending_invite.
func updateUserProfileSchema(client *okta.Client) error { func updateUserProfileSchema(client *okta.Client) error {

View File

@@ -428,7 +428,7 @@ func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMe
return err return err
} }
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID) resource := fmt.Sprintf("users/%s", userID)
_, err = zm.post(resource, string(payload)) _, err = zm.post(resource, string(payload))
if err != nil { if err != nil {
return err return err
@@ -447,6 +447,21 @@ func (zm *ZitadelManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Zitadel
func (zm *ZitadelManager) DeleteUser(userID string) error {
resource := fmt.Sprintf("users/%s", userID)
if err := zm.delete(resource); err != nil {
return err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountDeleteUser()
}
return nil
}
// getUserMetadata requests user metadata from zitadel via ID. // getUserMetadata requests user metadata from zitadel via ID.
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) { func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
resource := fmt.Sprintf("users/%s/metadata/_search", userID) resource := fmt.Sprintf("users/%s/metadata/_search", userID)
@@ -500,6 +515,42 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
return io.ReadAll(resp.Body) return io.ReadAll(resp.Body)
} }
// delete perform Delete requests.
func (zm *ZitadelManager) delete(resource string) error {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := zm.httpClient.Do(req)
if err != nil {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete %s, statusCode %d", reqURL, resp.StatusCode)
}
return nil
}
// get perform Get requests. // get perform Get requests.
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) { func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate() jwtToken, err := zm.credentials.Authenticate()

View File

@@ -412,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
peersUpdateManager := NewPeersUpdateManager() peersUpdateManager := NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "", accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "",
eventStore) eventStore, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -503,7 +503,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
peersUpdateManager := server.NewPeersUpdateManager() peersUpdateManager := server.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore) eventStore, false)
if err != nil { if err != nil {
log.Fatalf("failed creating a manager: %v", err) log.Fatalf("failed creating a manager: %v", err)
} }

View File

@@ -31,7 +31,6 @@ type MockAccountManager struct {
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error) AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID string) (*server.Group, error) GetGroupFunc func(accountID, groupID string) (*server.Group, error)
SaveGroupFunc func(accountID, userID string, group *server.Group) error SaveGroupFunc func(accountID, userID string, group *server.Group) error
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
DeleteGroupFunc func(accountID, userId, groupID string) error DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*server.Group, error) ListGroupsFunc func(accountID string) ([]*server.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerKey string) error GroupAddPeerFunc func(accountID, groupID, peerKey string) error
@@ -54,7 +53,6 @@ type MockAccountManager struct {
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error SaveRouteFunc func(accountID, userID string, route *route.Route) error
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID, userID string) error DeleteRouteFunc func(accountID, routeID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
@@ -68,7 +66,6 @@ type MockAccountManager struct {
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroupFunc func(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
@@ -267,14 +264,6 @@ func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server.
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented") return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
} }
// UpdateGroup mock implementation of UpdateGroup from server.AccountManager interface
func (am *MockAccountManager) UpdateGroup(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
if am.UpdateGroupFunc != nil {
return am.UpdateGroupFunc(accountID, groupID, operations)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented")
}
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface // DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error { func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error {
if am.DeleteGroupFunc != nil { if am.DeleteGroupFunc != nil {
@@ -435,14 +424,6 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R
return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented") return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented")
} }
// UpdateRoute mock implementation of UpdateRoute from server.AccountManager interface
func (am *MockAccountManager) UpdateRoute(accountID, ruleID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
if am.UpdateRouteFunc != nil {
return am.UpdateRouteFunc(accountID, ruleID, operations)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateRoute not implemented")
}
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface // DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error { func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error {
if am.DeleteRouteFunc != nil { if am.DeleteRouteFunc != nil {
@@ -533,14 +514,6 @@ func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGr
return nil return nil
} }
// UpdateNameServerGroup mocks UpdateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
if am.UpdateNameServerGroupFunc != nil {
return am.UpdateNameServerGroupFunc(accountID, nsGroupID, userID, operations)
}
return nil, nil
}
// DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface // DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface
func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
if am.DeleteNameServerGroupFunc != nil { if am.DeleteNameServerGroupFunc != nil {

View File

@@ -3,7 +3,6 @@ package server
import ( import (
"errors" "errors"
"regexp" "regexp"
"strconv"
"unicode/utf8" "unicode/utf8"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -15,54 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
const ( const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// UpdateNameServerGroupName indicates a nameserver group name update operation
UpdateNameServerGroupName NameServerGroupUpdateOperationType = iota
// UpdateNameServerGroupDescription indicates a nameserver group description update operation
UpdateNameServerGroupDescription
// UpdateNameServerGroupNameServers indicates a nameserver group nameservers list update operation
UpdateNameServerGroupNameServers
// UpdateNameServerGroupGroups indicates a nameserver group' groups update operation
UpdateNameServerGroupGroups
// UpdateNameServerGroupEnabled indicates a nameserver group status update operation
UpdateNameServerGroupEnabled
// UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation
UpdateNameServerGroupPrimary
// UpdateNameServerGroupDomains indicates a nameserver group' domains update operation
UpdateNameServerGroupDomains
domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
)
// NameServerGroupUpdateOperationType operation type
type NameServerGroupUpdateOperationType int
func (t NameServerGroupUpdateOperationType) String() string {
switch t {
case UpdateNameServerGroupDescription:
return "UpdateNameServerGroupDescription"
case UpdateNameServerGroupName:
return "UpdateNameServerGroupName"
case UpdateNameServerGroupNameServers:
return "UpdateNameServerGroupNameServers"
case UpdateNameServerGroupGroups:
return "UpdateNameServerGroupGroups"
case UpdateNameServerGroupEnabled:
return "UpdateNameServerGroupEnabled"
case UpdateNameServerGroupPrimary:
return "UpdateNameServerGroupPrimary"
case UpdateNameServerGroupDomains:
return "UpdateNameServerGroupDomains"
default:
return "InvalidOperation"
}
}
// NameServerGroupUpdateOperation operation object with type and values to be applied
type NameServerGroupUpdateOperation struct {
Type NameServerGroupUpdateOperationType
Values []string
}
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
@@ -172,109 +124,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
return nil return nil
} }
// UpdateNameServerGroup updates existing nameserver group with set of operations
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
if len(operations) == 0 {
return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty")
}
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
if !ok {
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
}
newNSGroup := nsGroupToUpdate.Copy()
for _, operation := range operations {
valuesCount := len(operation.Values)
if valuesCount < 1 {
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
}
for _, value := range operation.Values {
if value == "" {
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
}
}
switch operation.Type {
case UpdateNameServerGroupDescription:
newNSGroup.Description = operation.Values[0]
case UpdateNameServerGroupName:
if valuesCount > 1 {
return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
}
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
if err != nil {
return nil, err
}
newNSGroup.Name = operation.Values[0]
case UpdateNameServerGroupNameServers:
var nsList []nbdns.NameServer
for _, url := range operation.Values {
ns, err := nbdns.ParseNameServerURL(url)
if err != nil {
return nil, err
}
nsList = append(nsList, ns)
}
err = validateNSList(nsList)
if err != nil {
return nil, err
}
newNSGroup.NameServers = nsList
case UpdateNameServerGroupGroups:
err = validateGroups(operation.Values, account.Groups)
if err != nil {
return nil, err
}
newNSGroup.Groups = operation.Values
case UpdateNameServerGroupEnabled:
enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
}
newNSGroup.Enabled = enabled
case UpdateNameServerGroupPrimary:
primary, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
}
newNSGroup.Primary = primary
case UpdateNameServerGroupDomains:
err = validateDomainInput(false, operation.Values)
if err != nil {
return nil, err
}
newNSGroup.Domains = operation.Values
}
}
account.NameServerGroups[nsGroupID] = newNSGroup
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name)
}
return newNSGroup.Copy(), nil
}
// DeleteNameServerGroup deletes nameserver group with nsGroupID // DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {

View File

@@ -655,323 +655,6 @@ func TestSaveNameServerGroup(t *testing.T) {
} }
} }
func TestUpdateNameServerGroup(t *testing.T) {
nsGroupID := "testingNSGroup"
existingNSGroup := &nbdns.NameServerGroup{
ID: nsGroupID,
Name: "super",
Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID},
Enabled: true,
}
testCases := []struct {
name string
existingNSGroup *nbdns.NameServerGroup
nsGroupID string
operations []NameServerGroupUpdateOperation
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedNSGroup *nbdns.NameServerGroup
}{
{
name: "Should Config Single Property",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"superNew"},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
ID: nsGroupID,
Name: "superNew",
Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID},
Enabled: true,
},
},
{
name: "Should Config Multiple Properties",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"superNew"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDescription,
Values: []string{"superDescription"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupGroups,
Values: []string{group1ID, group2ID},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupEnabled,
Values: []string{"false"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupPrimary,
Values: []string{"false"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDomains,
Values: []string{validDomain},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
ID: nsGroupID,
Name: "superNew",
Description: "superDescription",
Primary: false,
Domains: []string{validDomain},
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("127.0.0.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID, group2ID},
Enabled: false,
},
},
{
name: "Should Not Config On Invalid ID",
existingNSGroup: existingNSGroup,
nsGroupID: "nonExistingNSGroup",
errFunc: require.Error,
},
{
name: "Should Not Config On Empty Operations",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{},
errFunc: require.Error,
},
{
name: "Should Not Config On Empty Values",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Empty String",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{""},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Name Large String",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid On Existing Name",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{existingNSGroupName},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid On Multiple Name Values",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"nameOne", "nameTwo"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Boolean",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupEnabled,
Values: []string{"yes"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Nameservers Wrong Schema",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"https://127.0.0.1:53"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Nameservers Wrong IP",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"udp://8.8.8.300:53"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Large Number Of Nameservers",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53", "udp://8.8.4.4:53"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid GroupID",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupGroups,
Values: []string{"nonExistingGroupID"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Domains",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDomains,
Values: []string{invalidDomain},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Primary Status",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupPrimary,
Values: []string{"yes"},
},
},
errFunc: require.Error,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
am, err := createNSManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
err = am.Store.SaveAccount(account)
if err != nil {
t.Error("account should be saved")
}
updatedRoute, err := am.UpdateNameServerGroup(account.Id, testCase.nsGroupID, userID, testCase.operations)
testCase.errFunc(t, err)
if !testCase.shouldCreate {
return
}
testCase.expectedNSGroup.ID = updatedRoute.ID
if !testCase.expectedNSGroup.IsEqual(updatedRoute) {
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedNSGroup)
}
})
}
}
func TestDeleteNameServerGroup(t *testing.T) { func TestDeleteNameServerGroup(t *testing.T) {
nsGroupID := "testingNSGroup" nsGroupID := "testingNSGroup"
@@ -1061,7 +744,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore) return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false)
} }
func createNSStore(t *testing.T) (Store, error) { func createNSStore(t *testing.T) (Store, error) {

View File

@@ -2,7 +2,6 @@ package server
import ( import (
"net/netip" "net/netip"
"strconv"
"unicode/utf8" "unicode/utf8"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
@@ -13,57 +12,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
// UpdateRouteDescription indicates a route description update operation
UpdateRouteDescription RouteUpdateOperationType = iota
// UpdateRouteNetwork indicates a route IP update operation
UpdateRouteNetwork
// UpdateRoutePeer indicates a route peer update operation
UpdateRoutePeer
// UpdateRouteMetric indicates a route metric update operation
UpdateRouteMetric
// UpdateRouteMasquerade indicates a route masquerade update operation
UpdateRouteMasquerade
// UpdateRouteEnabled indicates a route enabled update operation
UpdateRouteEnabled
// UpdateRouteNetworkIdentifier indicates a route net ID update operation
UpdateRouteNetworkIdentifier
// UpdateRouteGroups indicates a group list update operation
UpdateRouteGroups
)
// RouteUpdateOperationType operation type
type RouteUpdateOperationType int
func (t RouteUpdateOperationType) String() string {
switch t {
case UpdateRouteDescription:
return "UpdateRouteDescription"
case UpdateRouteNetwork:
return "UpdateRouteNetwork"
case UpdateRoutePeer:
return "UpdateRoutePeer"
case UpdateRouteMetric:
return "UpdateRouteMetric"
case UpdateRouteMasquerade:
return "UpdateRouteMasquerade"
case UpdateRouteEnabled:
return "UpdateRouteEnabled"
case UpdateRouteNetworkIdentifier:
return "UpdateRouteNetworkIdentifier"
case UpdateRouteGroups:
return "UpdateRouteGroups"
default:
return "InvalidOperation"
}
}
// RouteUpdateOperation operation object with type and values to be applied
type RouteUpdateOperation struct {
Type RouteUpdateOperationType
Values []string
}
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
@@ -241,109 +189,6 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
return nil return nil
} }
// UpdateRoute updates existing route with set of operations
func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
routeToUpdate, ok := account.Routes[routeID]
if !ok {
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
}
newRoute := routeToUpdate.Copy()
for _, operation := range operations {
if len(operation.Values) != 1 {
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
}
switch operation.Type {
case UpdateRouteDescription:
newRoute.Description = operation.Values[0]
case UpdateRouteNetworkIdentifier:
if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
newRoute.NetID = operation.Values[0]
case UpdateRouteNetwork:
prefixType, prefix, err := route.ParseNetwork(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0])
}
err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix)
if err != nil {
return nil, err
}
newRoute.Network = prefix
newRoute.NetworkType = prefixType
case UpdateRoutePeer:
if operation.Values[0] != "" {
peer := account.GetPeer(operation.Values[0])
if peer == nil {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", operation.Values[0])
}
}
err = am.checkPrefixPeerExists(accountID, operation.Values[0], routeToUpdate.Network)
if err != nil {
return nil, err
}
newRoute.Peer = operation.Values[0]
case UpdateRouteMetric:
metric, err := strconv.Atoi(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
}
if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
operation.Values[0],
route.MinMetric,
route.MaxMetric,
)
}
newRoute.Metric = metric
case UpdateRouteMasquerade:
masquerade, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
}
newRoute.Masquerade = masquerade
case UpdateRouteEnabled:
enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
}
newRoute.Enabled = enabled
case UpdateRouteGroups:
err = validateGroups(operation.Values, account.Groups)
if err != nil {
return nil, err
}
newRoute.Groups = operation.Values
}
}
account.Routes[routeID] = newRoute
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
return nil, err
}
err = am.updateAccountPeers(account)
if err != nil {
return nil, status.Errorf(status.Internal, "failed to update account peers")
}
return newRoute, nil
}
// DeleteRoute deletes route with routeID // DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error { func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)

View File

@@ -524,265 +524,6 @@ func TestSaveRoute(t *testing.T) {
} }
} }
func TestUpdateRoute(t *testing.T) {
routeID := "testingRouteID"
existingRoute := &route.Route{
ID: routeID,
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superRoute",
NetworkType: route.IPv4Network,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
}
testCases := []struct {
name string
existingRoute *route.Route
operations []RouteUpdateOperation
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedRoute *route.Route
}{
{
name: "Happy Path Single OPS",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRoutePeer,
Values: []string{peer2ID},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: routeID,
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superRoute",
NetworkType: route.IPv4Network,
Peer: peer2ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
},
{
name: "Happy Path Multiple OPS",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteDescription,
Values: []string{"great"},
},
{
Type: UpdateRouteNetwork,
Values: []string{"192.168.0.0/24"},
},
{
Type: UpdateRoutePeer,
Values: []string{peer2ID},
},
{
Type: UpdateRouteMetric,
Values: []string{"3030"},
},
{
Type: UpdateRouteMasquerade,
Values: []string{"true"},
},
{
Type: UpdateRouteEnabled,
Values: []string{"false"},
},
{
Type: UpdateRouteNetworkIdentifier,
Values: []string{"megaRoute"},
},
{
Type: UpdateRouteGroups,
Values: []string{routeGroup2},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: routeID,
Network: netip.MustParsePrefix("192.168.0.0/24"),
NetID: "megaRoute",
NetworkType: route.IPv4Network,
Peer: peer2ID,
Description: "great",
Masquerade: true,
Metric: 3030,
Enabled: false,
Groups: []string{routeGroup2},
},
},
{
name: "Empty Values Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRoutePeer,
},
},
errFunc: require.Error,
},
{
name: "Multiple Values Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRoutePeer,
Values: []string{peer2ID, peer1ID},
},
},
errFunc: require.Error,
},
{
name: "Bad Prefix Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteNetwork,
Values: []string{"192.168.0.0/34"},
},
},
errFunc: require.Error,
},
{
name: "Bad Peer Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRoutePeer,
Values: []string{"non existing Peer"},
},
},
errFunc: require.Error,
},
{
name: "Empty Peer",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRoutePeer,
Values: []string{""},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: routeID,
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superRoute",
NetworkType: route.IPv4Network,
Peer: "",
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
},
{
name: "Large Network ID Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteNetworkIdentifier,
Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"},
},
},
errFunc: require.Error,
},
{
name: "Empty Network ID Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteNetworkIdentifier,
Values: []string{""},
},
},
errFunc: require.Error,
},
{
name: "Invalid Metric Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteMetric,
Values: []string{"999999"},
},
},
errFunc: require.Error,
},
{
name: "Invalid Boolean Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteMasquerade,
Values: []string{"yes"},
},
},
errFunc: require.Error,
},
{
name: "Invalid Group Should Fail",
existingRoute: existingRoute,
operations: []RouteUpdateOperation{
{
Type: UpdateRouteGroups,
Values: []string{routeInvalidGroup1},
},
},
errFunc: require.Error,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
am, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestRouteAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute
err = am.Store.SaveAccount(account)
if err != nil {
t.Error("account should be saved")
}
updatedRoute, err := am.UpdateRoute(account.Id, testCase.existingRoute.ID, testCase.operations)
testCase.errFunc(t, err)
if !testCase.shouldCreate {
return
}
testCase.expectedRoute.ID = updatedRoute.ID
if !testCase.expectedRoute.IsEqual(updatedRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedRoute)
}
})
}
}
func TestDeleteRoute(t *testing.T) { func TestDeleteRoute(t *testing.T) {
testingRoute := &route.Route{ testingRoute := &route.Route{
ID: "testingRoute", ID: "testingRoute",
@@ -940,7 +681,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore) return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false)
} }
func createRouterStore(t *testing.T) (Store, error) { func createRouterStore(t *testing.T) (Store, error) {

View File

@@ -2,6 +2,7 @@ package telemetry
import ( import (
"context" "context"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument" "go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/syncint64" "go.opentelemetry.io/otel/metric/instrument/syncint64"
@@ -13,6 +14,7 @@ type IDPMetrics struct {
getUserByEmailCounter syncint64.Counter getUserByEmailCounter syncint64.Counter
getAllAccountsCounter syncint64.Counter getAllAccountsCounter syncint64.Counter
createUserCounter syncint64.Counter createUserCounter syncint64.Counter
deleteUserCounter syncint64.Counter
getAccountCounter syncint64.Counter getAccountCounter syncint64.Counter
getUserByIDCounter syncint64.Counter getUserByIDCounter syncint64.Counter
authenticateRequestCounter syncint64.Counter authenticateRequestCounter syncint64.Counter
@@ -39,6 +41,10 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
deleteUserCounter, err := meter.SyncInt64().Counter("management.idp.delete.user.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1")) getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1"))
if err != nil { if err != nil {
return nil, err return nil, err
@@ -65,6 +71,7 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error)
getUserByEmailCounter: getUserByEmailCounter, getUserByEmailCounter: getUserByEmailCounter,
getAllAccountsCounter: getAllAccountsCounter, getAllAccountsCounter: getAllAccountsCounter,
createUserCounter: createUserCounter, createUserCounter: createUserCounter,
deleteUserCounter: deleteUserCounter,
getAccountCounter: getAccountCounter, getAccountCounter: getAccountCounter,
getUserByIDCounter: getUserByIDCounter, getUserByIDCounter: getUserByIDCounter,
authenticateRequestCounter: authenticateRequestCounter, authenticateRequestCounter: authenticateRequestCounter,
@@ -88,6 +95,11 @@ func (idpMetrics *IDPMetrics) CountCreateUser() {
idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1) idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1)
} }
// CountDeleteUser ...
func (idpMetrics *IDPMetrics) CountDeleteUser() {
idpMetrics.deleteUserCounter.Add(idpMetrics.ctx, 1)
}
// CountGetAllAccounts ... // CountGetAllAccounts ...
func (idpMetrics *IDPMetrics) CountGetAllAccounts() { func (idpMetrics *IDPMetrics) CountGetAllAccounts() {
idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1) idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1)

View File

@@ -309,6 +309,9 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
// DeleteUser deletes a user from the given account. // DeleteUser deletes a user from the given account.
func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error { func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
if initiatorUserID == targetUserID {
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
}
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -327,15 +330,43 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
if executingUser.Role != UserRoleAdmin { if executingUser.Role != UserRoleAdmin {
return status.Errorf(status.PermissionDenied, "only admins can delete service users") return status.Errorf(status.PermissionDenied, "only admins can delete users")
} }
if !targetUser.IsServiceUser { peers, err := account.FindUserPeers(targetUserID)
return status.Errorf(status.PermissionDenied, "regular users can not be deleted") if err != nil {
return status.Errorf(status.Internal, "failed to find user peers")
} }
meta := map[string]any{"name": targetUser.ServiceUserName} if err := am.expireAndUpdatePeers(account, peers); err != nil {
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta) log.Errorf("failed update deleted peers expiration: %s", err)
return err
}
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(account.Id, initiatorUserID, targetUserID)
if err != nil {
log.Errorf("failed to resolve email address: %s", err)
return err
}
var meta map[string]any
var eventAction activity.Activity
if targetUser.IsServiceUser {
meta = map[string]any{"name": targetUser.ServiceUserName}
eventAction = activity.ServiceUserDeleted
} else {
meta = map[string]any{"name": tuName, "email": tuEmail}
eventAction = activity.UserDeleted
}
am.storeEvent(initiatorUserID, targetUserID, accountID, eventAction, meta)
if !targetUser.IsServiceUser && !isNil(am.idpManager) {
err := am.deleteUserFromIDP(targetUserID, accountID)
if err != nil {
log.Debugf("failed to delete user from IDP: %s", targetUserID)
return err
}
}
delete(account.Users, targetUserID) delete(account.Users, targetUserID)
@@ -609,23 +640,10 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
if err != nil { if err != nil {
return nil, err return nil, err
} }
var peerIDs []string
for _, peer := range blockedPeers {
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return nil, err
}
}
am.peersUpdateManager.CloseChannels(peerIDs)
err = am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
return nil, err
if err := am.expireAndUpdatePeers(account, blockedPeers); err != nil {
log.Errorf("failed update expired peers: %s", err)
return nil, err
} }
} }
@@ -814,6 +832,67 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
return userInfos, nil return userInfos, nil
} }
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*Peer) error {
var peerIDs []string
for _, peer := range peers {
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
return err
}
am.storeEvent(
peer.UserID, peer.ID, account.Id,
activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()),
)
}
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(peerIDs)
if err := am.updateAccountPeers(account); err != nil {
return err
}
}
return nil
}
func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID string) error {
if am.userDeleteFromIDPEnabled {
log.Debugf("user %s deleted from IdP", targetUserID)
err := am.idpManager.DeleteUser(targetUserID)
if err != nil {
return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err)
}
} else {
err := am.idpManager.UpdateUserAppMetadata(targetUserID, idp.AppMetadata{})
if err != nil {
return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err)
}
_, err = am.refreshCache(accountID)
if err != nil {
log.Errorf("refresh account (%q) cache: %v", accountID, err)
}
}
return nil
}
func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(accountId, initiatorId, targetId string) (string, string, error) {
userInfos, err := am.GetUsersFromAccount(accountId, initiatorId)
if err != nil {
return "", "", err
}
for _, ui := range userInfos {
if ui.ID == targetId {
return ui.Email, ui.Name, nil
}
}
return "", "", fmt.Errorf("user info not found for user: %s", targetId)
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData { for _, user := range userData {
if user.ID == userID { if user.ID == userID {

View File

@@ -424,7 +424,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID]) assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
} }
func TestUser_DeleteUser_regularUser(t *testing.T) { func TestUser_DeleteUser_SelfDelete(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "") account := newAccountWithId(mockAccountID, mockUserID, "")
@@ -439,8 +439,35 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
} }
err = am.DeleteUser(mockAccountID, mockUserID, mockUserID) err = am.DeleteUser(mockAccountID, mockUserID, mockUserID)
if err == nil {
t.Fatalf("failed to prevent self deletion")
}
}
assert.Errorf(t, err, "Regular users can not be deleted (yet)") func TestUser_DeleteUser_regularUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
targetId := "user2"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
err = am.DeleteUser(mockAccountID, mockUserID, targetId)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
} }
func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) {

View File

@@ -5,6 +5,8 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
log "github.com/sirupsen/logrus"
) )
// WriteJson writes JSON config object to a file creating parent directories if required // WriteJson writes JSON config object to a file creating parent directories if required
@@ -54,6 +56,68 @@ func WriteJson(file string, obj interface{}) error {
return nil return nil
} }
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
func DirectWriteJson(file string, obj interface{}) error {
_, _, err := prepareConfigFileDir(file)
if err != nil {
return err
}
targetFile, err := openOrCreateFile(file)
if err != nil {
return err
}
defer func() {
err = targetFile.Close()
if err != nil {
log.Errorf("failed to close file %s: %v", file, err)
}
}()
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
}
err = targetFile.Truncate(0)
if err != nil {
return err
}
_, err = targetFile.Write(bs)
if err != nil {
return err
}
return nil
}
func openOrCreateFile(file string) (*os.File, error) {
s, err := os.Stat(file)
if err == nil {
return os.OpenFile(file, os.O_WRONLY, s.Mode())
}
if !os.IsNotExist(err) {
return nil, err
}
targetFile, err := os.Create(file)
if err != nil {
return nil, err
}
//no:lint
err = targetFile.Chmod(0640)
if err != nil {
_ = targetFile.Close()
return nil, err
}
return targetFile, nil
}
// ReadJson reads JSON config file and maps to a provided interface // ReadJson reads JSON config file and maps to a provided interface
func ReadJson(file string, res interface{}) (interface{}, error) { func ReadJson(file string, res interface{}) (interface{}, error) {