mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-23 02:36:42 +00:00
Compare commits
14 Commits
handle-use
...
v0.23.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
732afd8393 | ||
|
|
da7b6b11ad | ||
|
|
e260270825 | ||
|
|
d4b6d7646c | ||
|
|
8febab4076 | ||
|
|
34e2c6b943 | ||
|
|
0be8c72601 | ||
|
|
c34e53477f | ||
|
|
8d18190c94 | ||
|
|
06bec61be9 | ||
|
|
2135533f1d | ||
|
|
bb791d59f3 | ||
|
|
30f1c54ed1 | ||
|
|
5c8541ef42 |
41
.github/workflows/android-build-validation.yml
vendored
Normal file
41
.github/workflows/android-build-validation.yml
vendored
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
name: Android build validation
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: "1.20.x"
|
||||||
|
- name: Setup Android SDK
|
||||||
|
uses: android-actions/setup-android@v2
|
||||||
|
- name: NDK Cache
|
||||||
|
id: ndk-cache
|
||||||
|
uses: actions/cache@v3
|
||||||
|
with:
|
||||||
|
path: /usr/local/lib/android/sdk/ndk
|
||||||
|
key: ndk-cache-23.1.7779620
|
||||||
|
- name: Setup NDK
|
||||||
|
run: /usr/local/lib/android/sdk/tools/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||||
|
- name: install gomobile
|
||||||
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||||
|
- name: gomobile init
|
||||||
|
run: gomobile init
|
||||||
|
- name: build android nebtird lib
|
||||||
|
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||||
|
env:
|
||||||
|
CGO_ENABLED: 0
|
||||||
|
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||||
@@ -80,6 +80,7 @@ jobs:
|
|||||||
CI_NETBIRD_MGMT_IDP: "none"
|
CI_NETBIRD_MGMT_IDP: "none"
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||||
|
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||||
|
|
||||||
run: |
|
run: |
|
||||||
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||||
@@ -91,6 +92,7 @@ jobs:
|
|||||||
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
|
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
|
||||||
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
|
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
|
||||||
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
|
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
|
||||||
|
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
|
||||||
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
||||||
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
||||||
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
||||||
@@ -120,7 +122,7 @@ jobs:
|
|||||||
|
|
||||||
- name: test running containers
|
- name: test running containers
|
||||||
run: |
|
run: |
|
||||||
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
|
count=$(docker compose ps --format json | jq '. | select(.Name | contains("infrastructure_files")) | .State' | grep -c running)
|
||||||
test $count -eq 4
|
test $count -eq 4
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
|
|
||||||
|
|||||||
@@ -84,10 +84,14 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
|||||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
supportsSSO := true
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
s, ok := gstatus.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||||
supportsSSO = false
|
supportsSSO = false
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -195,60 +193,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
|||||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
browserAuthMsg := "Please do the SSO login in your browser. \n" +
|
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||||
verificationURIComplete + " " + codeMsg
|
verificationURIComplete + " " + codeMsg)
|
||||||
|
cmd.Println("")
|
||||||
setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" +
|
if err := open.Run(verificationURIComplete); err != nil {
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys"
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
authenticateUsingBrowser := func() {
|
|
||||||
cmd.Println(browserAuthMsg)
|
|
||||||
cmd.Println("")
|
|
||||||
if err := open.Run(verificationURIComplete); err != nil {
|
|
||||||
cmd.Println(setupKeyAuthMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "windows", "darwin":
|
|
||||||
authenticateUsingBrowser()
|
|
||||||
case "linux":
|
|
||||||
if isLinuxRunningDesktop() {
|
|
||||||
authenticateUsingBrowser()
|
|
||||||
} else {
|
|
||||||
// If current flow is PKCE, it implies the server is anticipating the redirect to localhost.
|
|
||||||
// Devices lacking browser support are incompatible with this flow.Therefore,
|
|
||||||
// these devices will need to resort to setup keys instead.
|
|
||||||
if isPKCEFlow(verificationURIComplete) {
|
|
||||||
cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" +
|
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
|
||||||
} else {
|
|
||||||
cmd.Println(browserAuthMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment.
|
|
||||||
func isLinuxRunningDesktop() bool {
|
|
||||||
for _, env := range os.Environ() {
|
|
||||||
values := strings.Split(env, "=")
|
|
||||||
if len(values) == 2 {
|
|
||||||
key, value := values[0], values[1]
|
|
||||||
if key == "XDG_CURRENT_DESKTOP" && value != "" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// isPKCEFlow determines if the PKCE flow is active or not,
|
|
||||||
// by checking the existence of redirect_uri inside the verification URL.
|
|
||||||
func isPKCEFlow(verificationURL string) bool {
|
|
||||||
if verificationURL == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return strings.Contains(verificationURL, "redirect_uri")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||||
eventStore)
|
eventStore, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux
|
//go:build !linux || android
|
||||||
|
|
||||||
package acl
|
package acl
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
@@ -57,25 +57,43 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
return t.AccessToken
|
return t.AccessToken
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
|
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||||
|
//
|
||||||
|
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||||
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
|
//
|
||||||
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||||
log.Debug("loading pkce authorization flow info")
|
if runtime.GOOS == "linux" && !isLinuxRunningDesktop() {
|
||||||
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
|
||||||
if err == nil {
|
|
||||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("loading pkce authorization flow info failed with error: %v", err)
|
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||||
log.Debugf("falling back to device authorization flow info")
|
if err != nil {
|
||||||
|
// fallback to device code flow
|
||||||
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
|
}
|
||||||
|
return pkceFlow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||||
|
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||||
|
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||||
|
}
|
||||||
|
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||||
|
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s, ok := gstatus.FromError(err)
|
s, ok := gstatus.FromError(err)
|
||||||
if ok && s.Code() == codes.NotFound {
|
if ok && s.Code() == codes.NotFound {
|
||||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||||
"If you are using hosting Netbird see documentation at " +
|
"Please proceed with setting up this device using setup keys " +
|
||||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
} else if ok && s.Code() == codes.Unimplemented {
|
} else if ok && s.Code() == codes.Unimplemented {
|
||||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||||
"please update your server or use Setup Keys to login", config.ManagementURL)
|
"please update your server or use Setup Keys to login", config.ManagementURL)
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -80,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RequestAuthInfo requests a authorization code login flow information.
|
// RequestAuthInfo requests a authorization code login flow information.
|
||||||
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
|
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||||
state, err := randomBytesInHex(24)
|
state, err := randomBytesInHex(24)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
||||||
@@ -114,64 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
tokenChan := make(chan *oauth2.Token, 1)
|
tokenChan := make(chan *oauth2.Token, 1)
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
go p.startServer(tokenChan, errChan)
|
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
||||||
|
if err != nil {
|
||||||
|
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||||
|
defer func() {
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
|
log.Errorf("failed to close the server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go p.startServer(server, tokenChan, errChan)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return TokenInfo{}, ctx.Err()
|
return TokenInfo{}, ctx.Err()
|
||||||
case token := <-tokenChan:
|
case token := <-tokenChan:
|
||||||
return p.handleOAuthToken(token)
|
return p.parseOAuthToken(token)
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
return TokenInfo{}, err
|
return TokenInfo{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||||
var wg sync.WaitGroup
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||||
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
token, err := p.handleRequest(req)
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
|
||||||
go func() {
|
|
||||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
errChan <- err
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
tokenValidatorFunc := func() (*oauth2.Token, error) {
|
|
||||||
query := req.URL.Query()
|
|
||||||
|
|
||||||
if authError := query.Get(queryError); authError != "" {
|
|
||||||
authErrorDesc := query.Get(queryErrorDesc)
|
|
||||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prevent timing attacks on state
|
|
||||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
|
||||||
return nil, fmt.Errorf("invalid state")
|
|
||||||
}
|
|
||||||
|
|
||||||
code := query.Get(queryCode)
|
|
||||||
if code == "" {
|
|
||||||
return nil, fmt.Errorf("missing code")
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.oAuthConfig.Exchange(
|
|
||||||
req.Context(),
|
|
||||||
code,
|
|
||||||
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := tokenValidatorFunc()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
renderPKCEFlowTmpl(w, err)
|
renderPKCEFlowTmpl(w, err)
|
||||||
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
||||||
@@ -182,13 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC
|
|||||||
tokenChan <- token
|
tokenChan <- token
|
||||||
})
|
})
|
||||||
|
|
||||||
wg.Wait()
|
server.Handler = mux
|
||||||
if err := server.Shutdown(context.Background()); err != nil {
|
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Errorf("error while shutting down pkce flow server: %v", err)
|
errChan <- err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
|
||||||
|
query := req.URL.Query()
|
||||||
|
|
||||||
|
if authError := query.Get(queryError); authError != "" {
|
||||||
|
authErrorDesc := query.Get(queryErrorDesc)
|
||||||
|
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevent timing attacks on the state
|
||||||
|
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||||
|
return nil, fmt.Errorf("invalid state")
|
||||||
|
}
|
||||||
|
|
||||||
|
code := query.Get(queryCode)
|
||||||
|
if code == "" {
|
||||||
|
return nil, fmt.Errorf("missing code")
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.oAuthConfig.Exchange(
|
||||||
|
req.Context(),
|
||||||
|
code,
|
||||||
|
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
||||||
tokenInfo := TokenInfo{
|
tokenInfo := TokenInfo{
|
||||||
AccessToken: token.AccessToken,
|
AccessToken: token.AccessToken,
|
||||||
RefreshToken: token.RefreshToken,
|
RefreshToken: token.RefreshToken,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -60,3 +61,8 @@ func isValidAccessToken(token string, audience string) error {
|
|||||||
|
|
||||||
return fmt.Errorf("invalid JWT token audience field")
|
return fmt.Errorf("invalid JWT token audience field")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
|
func isLinuxRunningDesktop() bool {
|
||||||
|
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
//go:build !linux
|
//go:build !linux || android
|
||||||
|
|
||||||
package checkfw
|
package checkfw
|
||||||
|
|||||||
@@ -1049,7 +1049,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
|
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||||
eventStore)
|
eventStore, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,6 +135,7 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
|||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||||
}
|
}
|
||||||
p.removeTurnConn(endpointPort)
|
p.removeTurnConn(endpointPort)
|
||||||
|
log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = p.sendPkg(buf[:n], endpointPort)
|
err = p.sendPkg(buf[:n], endpointPort)
|
||||||
@@ -158,7 +159,7 @@ func (p *WGEBPFProxy) proxyToRemote() {
|
|||||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||||
p.turnConnMutex.Unlock()
|
p.turnConnMutex.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("turn conn not found by port: %d", addr.Port)
|
log.Infof("turn conn not found by port: %d", addr.Port)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||||
ports:
|
ports:
|
||||||
- 10000:80
|
- $NETBIRD_SIGNAL_PORT:80
|
||||||
# # port and command for Let's Encrypt validation
|
# # port and command for Let's Encrypt validation
|
||||||
# - 443:443
|
# - 443:443
|
||||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||||
|
|||||||
@@ -22,3 +22,4 @@ NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email"
|
|||||||
NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP
|
NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP
|
||||||
NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID
|
NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID
|
||||||
NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
|
NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
|
||||||
|
NETBIRD_SIGNAL_PORT=12345
|
||||||
@@ -61,7 +61,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||||
eventStore)
|
eventStore, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/activity/sqlite"
|
"github.com/netbirdio/netbird/management/server/activity/sqlite"
|
||||||
httpapi "github.com/netbirdio/netbird/management/server/http"
|
httpapi "github.com/netbirdio/netbird/management/server/http"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
@@ -142,12 +143,22 @@ var (
|
|||||||
if disableSingleAccMode {
|
if disableSingleAccMode {
|
||||||
mgmtSingleAccModeDomain = ""
|
mgmtSingleAccModeDomain = ""
|
||||||
}
|
}
|
||||||
eventStore, err := sqlite.NewSQLiteStore(config.Datadir)
|
eventStore, key, err := initEventStore(config.Datadir, config.DataStoreEncryptionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to initialize database: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if key != "" {
|
||||||
|
log.Infof("update config with activity store key")
|
||||||
|
config.DataStoreEncryptionKey = key
|
||||||
|
err := updateMgmtConfig(mgmtConfig, config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write out store encryption key: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||||
dnsDomain, eventStore)
|
dnsDomain, eventStore, userDeleteFromIDPEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to build default manager: %v", err)
|
return fmt.Errorf("failed to build default manager: %v", err)
|
||||||
}
|
}
|
||||||
@@ -287,6 +298,20 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func initEventStore(dataDir string, key string) (activity.Store, string, error) {
|
||||||
|
var err error
|
||||||
|
if key == "" {
|
||||||
|
log.Debugf("generate new activity store encryption key")
|
||||||
|
key, err = sqlite.GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
store, err := sqlite.NewSQLiteStore(dataDir, key)
|
||||||
|
return store, key, err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func notifyStop(msg string) {
|
func notifyStop(msg string) {
|
||||||
select {
|
select {
|
||||||
case stopCh <- 1:
|
case stopCh <- 1:
|
||||||
@@ -440,6 +465,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
|||||||
return loadedConfig, err
|
return loadedConfig, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func updateMgmtConfig(path string, config *server.Config) error {
|
||||||
|
return util.DirectWriteJson(path, config)
|
||||||
|
}
|
||||||
|
|
||||||
// OIDCConfigResponse used for parsing OIDC config response
|
// OIDCConfigResponse used for parsing OIDC config response
|
||||||
type OIDCConfigResponse struct {
|
type OIDCConfigResponse struct {
|
||||||
Issuer string `json:"issuer"`
|
Issuer string `json:"issuer"`
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ var (
|
|||||||
disableMetrics bool
|
disableMetrics bool
|
||||||
disableSingleAccMode bool
|
disableSingleAccMode bool
|
||||||
idpSignKeyRefreshEnabled bool
|
idpSignKeyRefreshEnabled bool
|
||||||
|
userDeleteFromIDPEnabled bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird-mgmt",
|
Use: "netbird-mgmt",
|
||||||
@@ -56,6 +57,7 @@ func init() {
|
|||||||
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
|
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
|
||||||
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
|
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
|
||||||
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
|
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
|
||||||
|
mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account")
|
||||||
rootCmd.MarkFlagRequired("config") //nolint
|
rootCmd.MarkFlagRequired("config") //nolint
|
||||||
|
|
||||||
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")
|
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")
|
||||||
|
|||||||
@@ -80,7 +80,6 @@ type AccountManager interface {
|
|||||||
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
||||||
GetGroup(accountId, groupID string) (*Group, error)
|
GetGroup(accountId, groupID string) (*Group, error)
|
||||||
SaveGroup(accountID, userID string, group *Group) error
|
SaveGroup(accountID, userID string, group *Group) error
|
||||||
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
|
|
||||||
DeleteGroup(accountId, userId, groupID string) error
|
DeleteGroup(accountId, userId, groupID string) error
|
||||||
ListGroups(accountId string) ([]*Group, error)
|
ListGroups(accountId string) ([]*Group, error)
|
||||||
GroupAddPeer(accountId, groupID, peerID string) error
|
GroupAddPeer(accountId, groupID, peerID string) error
|
||||||
@@ -93,13 +92,11 @@ type AccountManager interface {
|
|||||||
GetRoute(accountID, routeID, userID string) (*route.Route, error)
|
GetRoute(accountID, routeID, userID string) (*route.Route, error)
|
||||||
CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||||
SaveRoute(accountID, userID string, route *route.Route) error
|
SaveRoute(accountID, userID string, route *route.Route) error
|
||||||
UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
|
|
||||||
DeleteRoute(accountID, routeID, userID string) error
|
DeleteRoute(accountID, routeID, userID string) error
|
||||||
ListRoutes(accountID, userID string) ([]*route.Route, error)
|
ListRoutes(accountID, userID string) ([]*route.Route, error)
|
||||||
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
||||||
SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||||
UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
|
|
||||||
DeleteNameServerGroup(accountID, nsGroupID, userID string) error
|
DeleteNameServerGroup(accountID, nsGroupID, userID string) error
|
||||||
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
|
||||||
GetDNSDomain() string
|
GetDNSDomain() string
|
||||||
@@ -133,6 +130,9 @@ type DefaultAccountManager struct {
|
|||||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||||
dnsDomain string
|
dnsDomain string
|
||||||
peerLoginExpiry Scheduler
|
peerLoginExpiry Scheduler
|
||||||
|
|
||||||
|
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
|
||||||
|
userDeleteFromIDPEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Settings represents Account settings structure that can be modified via API and Dashboard
|
// Settings represents Account settings structure that can be modified via API and Dashboard
|
||||||
@@ -738,18 +738,19 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
|||||||
|
|
||||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||||
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
|
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, userDeleteFromIDPEnabled bool,
|
||||||
) (*DefaultAccountManager, error) {
|
) (*DefaultAccountManager, error) {
|
||||||
am := &DefaultAccountManager{
|
am := &DefaultAccountManager{
|
||||||
Store: store,
|
Store: store,
|
||||||
peersUpdateManager: peersUpdateManager,
|
peersUpdateManager: peersUpdateManager,
|
||||||
idpManager: idpManager,
|
idpManager: idpManager,
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
cacheMux: sync.Mutex{},
|
cacheMux: sync.Mutex{},
|
||||||
cacheLoading: map[string]chan struct{}{},
|
cacheLoading: map[string]chan struct{}{},
|
||||||
dnsDomain: dnsDomain,
|
dnsDomain: dnsDomain,
|
||||||
eventStore: eventStore,
|
eventStore: eventStore,
|
||||||
peerLoginExpiry: NewDefaultScheduler(),
|
peerLoginExpiry: NewDefaultScheduler(),
|
||||||
|
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||||
}
|
}
|
||||||
allAccounts := store.GetAllAccounts()
|
allAccounts := store.GetAllAccounts()
|
||||||
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
||||||
@@ -874,33 +875,19 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func()
|
|||||||
return account.GetNextPeerExpiration()
|
return account.GetNextPeerExpiration()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expiredPeers := account.GetExpiredPeers()
|
||||||
var peerIDs []string
|
var peerIDs []string
|
||||||
for _, peer := range account.GetExpiredPeers() {
|
for _, peer := range expiredPeers {
|
||||||
if peer.Status.LoginExpired {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
peerIDs = append(peerIDs, peer.ID)
|
peerIDs = append(peerIDs, peer.ID)
|
||||||
peer.MarkLoginExpired(true)
|
|
||||||
account.UpdatePeer(peer)
|
|
||||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
|
|
||||||
return account.GetNextPeerExpiration()
|
|
||||||
}
|
|
||||||
am.storeEvent(peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
|
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
|
||||||
|
|
||||||
if len(peerIDs) != 0 {
|
if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil {
|
||||||
// this will trigger peer disconnect from the management service
|
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
|
||||||
am.peersUpdateManager.CloseChannels(peerIDs)
|
return account.GetNextPeerExpiration()
|
||||||
err = am.updateAccountPeers(account)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
|
|
||||||
return account.GetNextPeerExpiration()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return account.GetNextPeerExpiration()
|
return account.GetNextPeerExpiration()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1605,19 +1592,3 @@ func newAccountWithId(accountID, userID, domain string) *Account {
|
|||||||
}
|
}
|
||||||
return acc
|
return acc
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeFromList(inputList []string, toRemove []string) []string {
|
|
||||||
toRemoveMap := make(map[string]struct{})
|
|
||||||
for _, item := range toRemove {
|
|
||||||
toRemoveMap[item] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var resultList []string
|
|
||||||
for _, item := range inputList {
|
|
||||||
_, ok := toRemoveMap[item]
|
|
||||||
if !ok {
|
|
||||||
resultList = append(resultList, item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return resultList
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2063,7 +2063,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore)
|
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createStore(t *testing.T) (Store, error) {
|
func createStore(t *testing.T) (Store, error) {
|
||||||
|
|||||||
@@ -104,6 +104,8 @@ const (
|
|||||||
UserBlocked
|
UserBlocked
|
||||||
// UserUnblocked indicates that a user unblocked another user
|
// UserUnblocked indicates that a user unblocked another user
|
||||||
UserUnblocked
|
UserUnblocked
|
||||||
|
// UserDeleted indicates that a user deleted another user
|
||||||
|
UserDeleted
|
||||||
// GroupDeleted indicates that a user deleted group
|
// GroupDeleted indicates that a user deleted group
|
||||||
GroupDeleted
|
GroupDeleted
|
||||||
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
|
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
|
||||||
@@ -162,6 +164,7 @@ var activityMap = map[Activity]Code{
|
|||||||
ServiceUserDeleted: {"Service user deleted", "service.user.delete"},
|
ServiceUserDeleted: {"Service user deleted", "service.user.delete"},
|
||||||
UserBlocked: {"User blocked", "user.block"},
|
UserBlocked: {"User blocked", "user.block"},
|
||||||
UserUnblocked: {"User unblocked", "user.unblock"},
|
UserUnblocked: {"User unblocked", "user.unblock"},
|
||||||
|
UserDeleted: {"User deleted", "user.delete"},
|
||||||
GroupDeleted: {"Group deleted", "group.delete"},
|
GroupDeleted: {"Group deleted", "group.delete"},
|
||||||
UserLoggedInPeer: {"User logged in peer", "user.peer.login"},
|
UserLoggedInPeer: {"User logged in peer", "user.peer.login"},
|
||||||
PeerLoginExpired: {"Peer login expired", "peer.login.expire"},
|
PeerLoginExpired: {"Peer login expired", "peer.login.expire"},
|
||||||
|
|||||||
@@ -18,10 +18,15 @@ type Event struct {
|
|||||||
ID uint64
|
ID uint64
|
||||||
// InitiatorID is the ID of an object that initiated the event (e.g., a user)
|
// InitiatorID is the ID of an object that initiated the event (e.g., a user)
|
||||||
InitiatorID string
|
InitiatorID string
|
||||||
|
// InitiatorName is the name of an object that initiated the event.
|
||||||
|
InitiatorName string
|
||||||
|
// InitiatorEmail is the email address of an object that initiated the event.
|
||||||
|
InitiatorEmail string
|
||||||
// TargetID is the ID of an object that was effected by the event (e.g., a peer)
|
// TargetID is the ID of an object that was effected by the event (e.g., a peer)
|
||||||
TargetID string
|
TargetID string
|
||||||
// AccountID is the ID of an account where the event happened
|
// AccountID is the ID of an account where the event happened
|
||||||
AccountID string
|
AccountID string
|
||||||
|
|
||||||
// Meta of the event, e.g. deleted peer information like name, IP, etc
|
// Meta of the event, e.g. deleted peer information like name, IP, etc
|
||||||
Meta map[string]any
|
Meta map[string]any
|
||||||
}
|
}
|
||||||
@@ -35,12 +40,14 @@ func (e *Event) Copy() *Event {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Event{
|
return &Event{
|
||||||
Timestamp: e.Timestamp,
|
Timestamp: e.Timestamp,
|
||||||
Activity: e.Activity,
|
Activity: e.Activity,
|
||||||
ID: e.ID,
|
ID: e.ID,
|
||||||
InitiatorID: e.InitiatorID,
|
InitiatorID: e.InitiatorID,
|
||||||
TargetID: e.TargetID,
|
InitiatorName: e.InitiatorName,
|
||||||
AccountID: e.AccountID,
|
InitiatorEmail: e.InitiatorEmail,
|
||||||
Meta: meta,
|
TargetID: e.TargetID,
|
||||||
|
AccountID: e.AccountID,
|
||||||
|
Meta: meta,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
81
management/server/activity/sqlite/crypt.go
Normal file
81
management/server/activity/sqlite/crypt.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
|
||||||
|
|
||||||
|
type FieldEncrypt struct {
|
||||||
|
block cipher.Block
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateKey() (string, error) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
readableKey := base64.StdEncoding.EncodeToString(key)
|
||||||
|
return readableKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
|
||||||
|
binKey, err := base64.StdEncoding.DecodeString(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(binKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ec := &FieldEncrypt{
|
||||||
|
block: block,
|
||||||
|
}
|
||||||
|
|
||||||
|
return ec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *FieldEncrypt) Encrypt(payload string) string {
|
||||||
|
plainText := pkcs5Padding([]byte(payload))
|
||||||
|
cipherText := make([]byte, len(plainText))
|
||||||
|
cbc := cipher.NewCBCEncrypter(ec.block, iv)
|
||||||
|
cbc.CryptBlocks(cipherText, plainText)
|
||||||
|
return base64.StdEncoding.EncodeToString(cipherText)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
|
||||||
|
cipherText, err := base64.StdEncoding.DecodeString(data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
cbc := cipher.NewCBCDecrypter(ec.block, iv)
|
||||||
|
cbc.CryptBlocks(cipherText, cipherText)
|
||||||
|
payload, err := pkcs5UnPadding(cipherText)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(payload), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pkcs5Padding(ciphertext []byte) []byte {
|
||||||
|
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
|
||||||
|
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||||
|
return append(ciphertext, padText...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pkcs5UnPadding(src []byte) ([]byte, error) {
|
||||||
|
srcLen := len(src)
|
||||||
|
paddingLen := int(src[srcLen-1])
|
||||||
|
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
|
||||||
|
return nil, fmt.Errorf("padding size error")
|
||||||
|
}
|
||||||
|
return src[:srcLen-paddingLen], nil
|
||||||
|
}
|
||||||
63
management/server/activity/sqlite/crypt_test.go
Normal file
63
management/server/activity/sqlite/crypt_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateKey(t *testing.T) {
|
||||||
|
testData := "exampl@netbird.io"
|
||||||
|
key, err := GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate key: %s", err)
|
||||||
|
}
|
||||||
|
ee, err := NewFieldEncrypt(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to init email encryption: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted := ee.Encrypt(testData)
|
||||||
|
if encrypted == "" {
|
||||||
|
t.Fatalf("invalid encrypted text")
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, err := ee.Decrypt(encrypted)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to decrypt data: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decrypted != testData {
|
||||||
|
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCorruptKey(t *testing.T) {
|
||||||
|
testData := "exampl@netbird.io"
|
||||||
|
key, err := GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate key: %s", err)
|
||||||
|
}
|
||||||
|
ee, err := NewFieldEncrypt(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to init email encryption: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted := ee.Encrypt(testData)
|
||||||
|
if encrypted == "" {
|
||||||
|
t.Fatalf("invalid encrypted text")
|
||||||
|
}
|
||||||
|
|
||||||
|
newKey, err := GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ee, err = NewFieldEncrypt(newKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to init email encryption: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, _ := ee.Decrypt(encrypted)
|
||||||
|
if res == testData {
|
||||||
|
t.Fatalf("incorrect decryption, the result is: %s", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,14 +3,14 @@ package sqlite
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
|
|
||||||
// sqlite driver
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -25,69 +25,122 @@ const (
|
|||||||
"meta TEXT," +
|
"meta TEXT," +
|
||||||
" target_id TEXT);"
|
" target_id TEXT);"
|
||||||
|
|
||||||
selectDescQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" +
|
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`
|
||||||
" FROM events WHERE account_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?;"
|
|
||||||
selectAscQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" +
|
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
|
||||||
" FROM events WHERE account_id = ? ORDER BY timestamp ASC LIMIT ? OFFSET ?;"
|
FROM events
|
||||||
|
LEFT JOIN deleted_users i ON events.initiator_id = i.id
|
||||||
|
LEFT JOIN deleted_users t ON events.target_id = t.id
|
||||||
|
WHERE account_id = ?
|
||||||
|
ORDER BY timestamp DESC LIMIT ? OFFSET ?;`
|
||||||
|
|
||||||
|
selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
|
||||||
|
FROM events
|
||||||
|
LEFT JOIN deleted_users i ON events.initiator_id = i.id
|
||||||
|
LEFT JOIN deleted_users t ON events.target_id = t.id
|
||||||
|
WHERE account_id = ?
|
||||||
|
ORDER BY timestamp ASC LIMIT ? OFFSET ?;`
|
||||||
|
|
||||||
insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " +
|
insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " +
|
||||||
"VALUES(?, ?, ?, ?, ?, ?)"
|
"VALUES(?, ?, ?, ?, ?, ?)"
|
||||||
|
|
||||||
|
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
|
||||||
)
|
)
|
||||||
|
|
||||||
// Store is the implementation of the activity.Store interface backed by SQLite
|
// Store is the implementation of the activity.Store interface backed by SQLite
|
||||||
type Store struct {
|
type Store struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
fieldEncrypt *FieldEncrypt
|
||||||
|
|
||||||
insertStatement *sql.Stmt
|
insertStatement *sql.Stmt
|
||||||
selectAscStatement *sql.Stmt
|
selectAscStatement *sql.Stmt
|
||||||
selectDescStatement *sql.Stmt
|
selectDescStatement *sql.Stmt
|
||||||
|
deleteUserStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSQLiteStore creates a new Store with an event table if not exists.
|
// NewSQLiteStore creates a new Store with an event table if not exists.
|
||||||
func NewSQLiteStore(dataDir string) (*Store, error) {
|
func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
|
||||||
dbFile := filepath.Join(dataDir, eventSinkDB)
|
dbFile := filepath.Join(dataDir, eventSinkDB)
|
||||||
db, err := sql.Open("sqlite3", dbFile)
|
db, err := sql.Open("sqlite3", dbFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
crypt, err := NewFieldEncrypt(encryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
_, err = db.Exec(createTableQuery)
|
_, err = db.Exec(createTableQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec(creatTableDeletedUsersQuery)
|
||||||
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = updateDeletedUsersTable(db)
|
||||||
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
insertStmt, err := db.Prepare(insertQuery)
|
insertStmt, err := db.Prepare(insertQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
selectDescStmt, err := db.Prepare(selectDescQuery)
|
selectDescStmt, err := db.Prepare(selectDescQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
selectAscStmt, err := db.Prepare(selectAscQuery)
|
selectAscStmt, err := db.Prepare(selectAscQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Store{
|
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
|
||||||
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Store{
|
||||||
db: db,
|
db: db,
|
||||||
|
fieldEncrypt: crypt,
|
||||||
insertStatement: insertStmt,
|
insertStatement: insertStmt,
|
||||||
selectDescStatement: selectDescStmt,
|
selectDescStatement: selectDescStmt,
|
||||||
selectAscStatement: selectAscStmt,
|
selectAscStatement: selectAscStmt,
|
||||||
}, nil
|
deleteUserStmt: deleteUserStmt,
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func processResult(result *sql.Rows) ([]*activity.Event, error) {
|
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||||
events := make([]*activity.Event, 0)
|
events := make([]*activity.Event, 0)
|
||||||
for result.Next() {
|
for result.Next() {
|
||||||
var id int64
|
var id int64
|
||||||
var operation activity.Activity
|
var operation activity.Activity
|
||||||
var timestamp time.Time
|
var timestamp time.Time
|
||||||
var initiator string
|
var initiator string
|
||||||
|
var initiatorName *string
|
||||||
|
var initiatorEmail *string
|
||||||
var target string
|
var target string
|
||||||
|
var targetUserName *string
|
||||||
|
var targetEmail *string
|
||||||
var account string
|
var account string
|
||||||
var jsonMeta string
|
var jsonMeta string
|
||||||
err := result.Scan(&id, &operation, ×tamp, &initiator, &target, &account, &jsonMeta)
|
err := result.Scan(&id, &operation, ×tamp, &initiator, &initiatorName, &initiatorEmail, &target, &targetUserName, &targetEmail, &account, &jsonMeta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -100,7 +153,27 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
events = append(events, &activity.Event{
|
if targetUserName != nil {
|
||||||
|
name, err := store.fieldEncrypt.Decrypt(*targetUserName)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to decrypt username for target id: %s", target)
|
||||||
|
meta["username"] = ""
|
||||||
|
} else {
|
||||||
|
meta["username"] = name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetEmail != nil {
|
||||||
|
email, err := store.fieldEncrypt.Decrypt(*targetEmail)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to decrypt email address for target id: %s", target)
|
||||||
|
meta["email"] = ""
|
||||||
|
} else {
|
||||||
|
meta["email"] = email
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
event := &activity.Event{
|
||||||
Timestamp: timestamp,
|
Timestamp: timestamp,
|
||||||
Activity: operation,
|
Activity: operation,
|
||||||
ID: uint64(id),
|
ID: uint64(id),
|
||||||
@@ -108,7 +181,27 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
|
|||||||
TargetID: target,
|
TargetID: target,
|
||||||
AccountID: account,
|
AccountID: account,
|
||||||
Meta: meta,
|
Meta: meta,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if initiatorName != nil {
|
||||||
|
name, err := store.fieldEncrypt.Decrypt(*initiatorName)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to decrypt username of initiator: %s", initiator)
|
||||||
|
} else {
|
||||||
|
event.InitiatorName = name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if initiatorEmail != nil {
|
||||||
|
email, err := store.fieldEncrypt.Decrypt(*initiatorEmail)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to decrypt email address of initiator: %s", initiator)
|
||||||
|
} else {
|
||||||
|
event.InitiatorEmail = email
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
events = append(events, event)
|
||||||
}
|
}
|
||||||
|
|
||||||
return events, nil
|
return events, nil
|
||||||
@@ -127,13 +220,18 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer result.Close() //nolint
|
defer result.Close() //nolint
|
||||||
return processResult(result)
|
return store.processResult(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save an event in the SQLite events table
|
// Save an event in the SQLite events table end encrypt the "email" element in meta map
|
||||||
func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
|
func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
|
||||||
var jsonMeta string
|
var jsonMeta string
|
||||||
if event.Meta != nil {
|
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if meta != nil {
|
||||||
metaBytes, err := json.Marshal(event.Meta)
|
metaBytes, err := json.Marshal(event.Meta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -156,6 +254,34 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
|
|||||||
return eventCopy, nil
|
return eventCopy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// saveDeletedUserEmailAndNameInEncrypted if the meta contains email and name then store it in encrypted way and delete
|
||||||
|
// this item from meta map
|
||||||
|
func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event) (map[string]any, error) {
|
||||||
|
email, ok := event.Meta["email"]
|
||||||
|
if !ok {
|
||||||
|
return event.Meta, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name, ok := event.Meta["name"]
|
||||||
|
if !ok {
|
||||||
|
return event.Meta, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
|
||||||
|
encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
|
||||||
|
_, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(event.Meta) == 2 {
|
||||||
|
return nil, nil // nolint
|
||||||
|
}
|
||||||
|
delete(event.Meta, "email")
|
||||||
|
delete(event.Meta, "name")
|
||||||
|
return event.Meta, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Close the Store
|
// Close the Store
|
||||||
func (store *Store) Close() error {
|
func (store *Store) Close() error {
|
||||||
if store.db != nil {
|
if store.db != nil {
|
||||||
@@ -163,3 +289,44 @@ func (store *Store) Close() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func updateDeletedUsersTable(db *sql.DB) error {
|
||||||
|
log.Debugf("check deleted_users table version")
|
||||||
|
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
found := false
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
cid int
|
||||||
|
name string
|
||||||
|
dataType string
|
||||||
|
notNull int
|
||||||
|
dfltVal sql.NullString
|
||||||
|
pk int
|
||||||
|
)
|
||||||
|
err := rows.Scan(&cid, &name, &dataType, ¬Null, &dfltVal, &pk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if name == "name" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Err()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if found {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("update delted_users table")
|
||||||
|
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ import (
|
|||||||
|
|
||||||
func TestNewSQLiteStore(t *testing.T) {
|
func TestNewSQLiteStore(t *testing.T) {
|
||||||
dataDir := t.TempDir()
|
dataDir := t.TempDir()
|
||||||
store, err := NewSQLiteStore(dataDir)
|
key, _ := GenerateKey()
|
||||||
|
store, err := NewSQLiteStore(dataDir, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ type Config struct {
|
|||||||
TURNConfig *TURNConfig
|
TURNConfig *TURNConfig
|
||||||
Signal *Host
|
Signal *Host
|
||||||
|
|
||||||
Datadir string
|
Datadir string
|
||||||
|
DataStoreEncryptionKey string
|
||||||
|
|
||||||
HttpConfig *HttpServerConfig
|
HttpConfig *HttpServerConfig
|
||||||
|
|
||||||
|
|||||||
@@ -191,7 +191,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore)
|
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createDNSStore(t *testing.T) (Store, error) {
|
func createDNSStore(t *testing.T) (Store, error) {
|
||||||
|
|||||||
@@ -33,26 +33,6 @@ type Group struct {
|
|||||||
Peers []string
|
Peers []string
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
// UpdateGroupName indicates a name update operation
|
|
||||||
UpdateGroupName GroupUpdateOperationType = iota
|
|
||||||
// InsertPeersToGroup indicates insert peers to group operation
|
|
||||||
InsertPeersToGroup
|
|
||||||
// RemovePeersFromGroup indicates a remove peers from group operation
|
|
||||||
RemovePeersFromGroup
|
|
||||||
// UpdateGroupPeers indicates a replacement of group peers list
|
|
||||||
UpdateGroupPeers
|
|
||||||
)
|
|
||||||
|
|
||||||
// GroupUpdateOperationType operation type
|
|
||||||
type GroupUpdateOperationType int
|
|
||||||
|
|
||||||
// GroupUpdateOperation operation object with type and values to be applied
|
|
||||||
type GroupUpdateOperation struct {
|
|
||||||
Type GroupUpdateOperationType
|
|
||||||
Values []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// EventMeta returns activity event meta related to the group
|
// EventMeta returns activity event meta related to the group
|
||||||
func (g *Group) EventMeta() map[string]any {
|
func (g *Group) EventMeta() map[string]any {
|
||||||
return map[string]any{"name": g.Name}
|
return map[string]any{"name": g.Name}
|
||||||
@@ -165,57 +145,6 @@ func difference(a, b []string) []string {
|
|||||||
return diff
|
return diff
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateGroup updates a group using a list of operations
|
|
||||||
func (am *DefaultAccountManager) UpdateGroup(accountID string,
|
|
||||||
groupID string, operations []GroupUpdateOperation,
|
|
||||||
) (*Group, error) {
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
groupToUpdate, ok := account.Groups[groupID]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID)
|
|
||||||
}
|
|
||||||
|
|
||||||
group := groupToUpdate.Copy()
|
|
||||||
|
|
||||||
for _, operation := range operations {
|
|
||||||
switch operation.Type {
|
|
||||||
case UpdateGroupName:
|
|
||||||
group.Name = operation.Values[0]
|
|
||||||
case UpdateGroupPeers:
|
|
||||||
group.Peers = operation.Values
|
|
||||||
case InsertPeersToGroup:
|
|
||||||
sourceList := group.Peers
|
|
||||||
resultList := removeFromList(sourceList, operation.Values)
|
|
||||||
group.Peers = append(resultList, operation.Values...)
|
|
||||||
case RemovePeersFromGroup:
|
|
||||||
sourceList := group.Peers
|
|
||||||
resultList := removeFromList(sourceList, operation.Values)
|
|
||||||
group.Peers = resultList
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Groups[groupID] = group
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(account); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return group, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteGroup object of the peers
|
// DeleteGroup object of the peers
|
||||||
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountId)
|
unlock := am.Store.AcquireAccountLock(accountId)
|
||||||
|
|||||||
0
management/server/http/api/generate.sh
Normal file → Executable file
0
management/server/http/api/generate.sh
Normal file → Executable file
@@ -922,6 +922,14 @@ components:
|
|||||||
description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
|
description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
|
||||||
type: string
|
type: string
|
||||||
example: google-oauth2|123456789012345678901
|
example: google-oauth2|123456789012345678901
|
||||||
|
initiator_name:
|
||||||
|
description: The name of the initiator of the event.
|
||||||
|
type: string
|
||||||
|
example: John Doe
|
||||||
|
initiator_email:
|
||||||
|
description: The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event.
|
||||||
|
type: string
|
||||||
|
example: demo@netbird.io
|
||||||
target_id:
|
target_id:
|
||||||
description: The ID of the target of the event. E.g., an ID of the peer that a user removed.
|
description: The ID of the target of the event. E.g., an ID of the peer that a user removed.
|
||||||
type: string
|
type: string
|
||||||
@@ -938,6 +946,8 @@ components:
|
|||||||
- activity
|
- activity
|
||||||
- activity_code
|
- activity_code
|
||||||
- initiator_id
|
- initiator_id
|
||||||
|
- initiator_name
|
||||||
|
- initiator_email
|
||||||
- target_id
|
- target_id
|
||||||
- meta
|
- meta
|
||||||
responses:
|
responses:
|
||||||
|
|||||||
@@ -164,9 +164,15 @@ type Event struct {
|
|||||||
// Id Event unique identifier
|
// Id Event unique identifier
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// InitiatorEmail The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event.
|
||||||
|
InitiatorEmail string `json:"initiator_email"`
|
||||||
|
|
||||||
// InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
|
// InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
|
||||||
InitiatorId string `json:"initiator_id"`
|
InitiatorId string `json:"initiator_id"`
|
||||||
|
|
||||||
|
// InitiatorName The name of the initiator of the event.
|
||||||
|
InitiatorName string `json:"initiator_name"`
|
||||||
|
|
||||||
// Meta The metadata of the event
|
// Meta The metadata of the event
|
||||||
Meta map[string]string `json:"meta"`
|
Meta map[string]string `json:"meta"`
|
||||||
|
|
||||||
|
|||||||
@@ -45,14 +45,66 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
events := make([]*api.Event, 0)
|
events := make([]*api.Event, len(accountEvents))
|
||||||
for _, e := range accountEvents {
|
for i, e := range accountEvents {
|
||||||
events = append(events, toEventResponse(e))
|
events[i] = toEventResponse(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.fillEventsWithUserInfo(events, account.Id, user.Id)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(err, w)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(w, events)
|
util.WriteJSONObject(w, events)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error {
|
||||||
|
// build email, name maps based on users
|
||||||
|
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get users from account: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
emails := make(map[string]string)
|
||||||
|
names := make(map[string]string)
|
||||||
|
for _, ui := range userInfos {
|
||||||
|
emails[ui.ID] = ui.Email
|
||||||
|
names[ui.ID] = ui.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
var ok bool
|
||||||
|
for _, event := range events {
|
||||||
|
// fill initiator
|
||||||
|
if event.InitiatorEmail == "" {
|
||||||
|
event.InitiatorEmail, ok = emails[event.InitiatorId]
|
||||||
|
if !ok {
|
||||||
|
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.InitiatorName == "" {
|
||||||
|
// here to allowed to be empty because in the first release we did not store the name
|
||||||
|
event.InitiatorName = names[event.InitiatorId]
|
||||||
|
}
|
||||||
|
|
||||||
|
// fill target meta
|
||||||
|
email, ok := emails[event.TargetId]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
event.Meta["email"] = email
|
||||||
|
|
||||||
|
username, ok := names[event.TargetId]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
event.Meta["username"] = username
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func toEventResponse(event *activity.Event) *api.Event {
|
func toEventResponse(event *activity.Event) *api.Event {
|
||||||
meta := make(map[string]string)
|
meta := make(map[string]string)
|
||||||
if event.Meta != nil {
|
if event.Meta != nil {
|
||||||
@@ -60,13 +112,16 @@ func toEventResponse(event *activity.Event) *api.Event {
|
|||||||
meta[s] = fmt.Sprintf("%v", a)
|
meta[s] = fmt.Sprintf("%v", a)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &api.Event{
|
e := &api.Event{
|
||||||
Id: fmt.Sprint(event.ID),
|
Id: fmt.Sprint(event.ID),
|
||||||
InitiatorId: event.InitiatorID,
|
InitiatorId: event.InitiatorID,
|
||||||
Activity: event.Activity.Message(),
|
InitiatorName: event.InitiatorName,
|
||||||
ActivityCode: api.EventActivityCode(event.Activity.StringCode()),
|
InitiatorEmail: event.InitiatorEmail,
|
||||||
TargetId: event.TargetID,
|
Activity: event.Activity.Message(),
|
||||||
Timestamp: event.Timestamp,
|
ActivityCode: api.EventActivityCode(event.Activity.StringCode()),
|
||||||
Meta: meta,
|
TargetId: event.TargetID,
|
||||||
|
Timestamp: event.Timestamp,
|
||||||
|
Meta: meta,
|
||||||
}
|
}
|
||||||
|
return e
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
|
|||||||
},
|
},
|
||||||
}, user, nil
|
}, user, nil
|
||||||
},
|
},
|
||||||
|
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
||||||
|
return make([]*server.UserInfo, 0), nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
|
|||||||
@@ -53,22 +53,6 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
|
|||||||
Issued: server.GroupIssuedAPI,
|
Issued: server.GroupIssuedAPI,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
|
|
||||||
var group server.Group
|
|
||||||
group.ID = groupID
|
|
||||||
for _, operation := range operations {
|
|
||||||
switch operation.Type {
|
|
||||||
case server.UpdateGroupName:
|
|
||||||
group.Name = operation.Values[0]
|
|
||||||
case server.UpdateGroupPeers, server.InsertPeersToGroup:
|
|
||||||
group.Peers = operation.Values
|
|
||||||
case server.RemovePeersFromGroup:
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("no operation")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &group, nil
|
|
||||||
},
|
|
||||||
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
||||||
for _, peer := range TestPeers {
|
for _, peer := range TestPeers {
|
||||||
if peer.IP.String() == peerIP {
|
if peer.IP.String() == peerIP {
|
||||||
|
|||||||
@@ -88,31 +88,6 @@ func initNameserversTestData() *NameserversHandler {
|
|||||||
}
|
}
|
||||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||||
},
|
},
|
||||||
UpdateNameServerGroupFunc: func(accountID, nsGroupID, _ string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
|
|
||||||
nsGroupToUpdate := baseExistingNSGroup.Copy()
|
|
||||||
if nsGroupID != nsGroupToUpdate.ID {
|
|
||||||
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
|
||||||
}
|
|
||||||
for _, operation := range operations {
|
|
||||||
switch operation.Type {
|
|
||||||
case server.UpdateNameServerGroupName:
|
|
||||||
nsGroupToUpdate.Name = operation.Values[0]
|
|
||||||
case server.UpdateNameServerGroupDescription:
|
|
||||||
nsGroupToUpdate.Description = operation.Values[0]
|
|
||||||
case server.UpdateNameServerGroupNameServers:
|
|
||||||
var parsedNSList []nbdns.NameServer
|
|
||||||
for _, nsURL := range operation.Values {
|
|
||||||
parsed, err := nbdns.ParseNameServerURL(nsURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
parsedNSList = append(parsedNSList, parsed)
|
|
||||||
}
|
|
||||||
nsGroupToUpdate.NameServers = parsedNSList
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nsGroupToUpdate, nil
|
|
||||||
},
|
|
||||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
return testingNSAccount, testingAccount.Users["test_user"], nil
|
return testingNSAccount, testingAccount.Users["test_user"], nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
@@ -108,38 +107,6 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
|
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
|
|
||||||
routeToUpdate := baseExistingRoute
|
|
||||||
if routeID != routeToUpdate.ID {
|
|
||||||
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
|
|
||||||
}
|
|
||||||
for _, operation := range operations {
|
|
||||||
switch operation.Type {
|
|
||||||
case server.UpdateRouteNetwork:
|
|
||||||
routeToUpdate.NetworkType, routeToUpdate.Network, _ = route.ParseNetwork(operation.Values[0])
|
|
||||||
case server.UpdateRouteDescription:
|
|
||||||
routeToUpdate.Description = operation.Values[0]
|
|
||||||
case server.UpdateRouteNetworkIdentifier:
|
|
||||||
routeToUpdate.NetID = operation.Values[0]
|
|
||||||
case server.UpdateRoutePeer:
|
|
||||||
routeToUpdate.Peer = operation.Values[0]
|
|
||||||
if routeToUpdate.Peer == notFoundPeerID {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", routeToUpdate.Peer)
|
|
||||||
}
|
|
||||||
case server.UpdateRouteMetric:
|
|
||||||
routeToUpdate.Metric, _ = strconv.Atoi(operation.Values[0])
|
|
||||||
case server.UpdateRouteMasquerade:
|
|
||||||
routeToUpdate.Masquerade, _ = strconv.ParseBool(operation.Values[0])
|
|
||||||
case server.UpdateRouteEnabled:
|
|
||||||
routeToUpdate.Enabled, _ = strconv.ParseBool(operation.Values[0])
|
|
||||||
case server.UpdateRouteGroups:
|
|
||||||
routeToUpdate.Groups = operation.Values
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("no operation")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return routeToUpdate, nil
|
|
||||||
},
|
|
||||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
return testingAccount, testingAccount.Users["test_user"], nil
|
return testingAccount, testingAccount.Users["test_user"], nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
|||||||
// WriteError converts an error to an JSON error response.
|
// WriteError converts an error to an JSON error response.
|
||||||
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
|
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
|
||||||
func WriteError(err error, w http.ResponseWriter) {
|
func WriteError(err error, w http.ResponseWriter) {
|
||||||
|
log.Errorf("got a handler error: %s", err.Error())
|
||||||
errStatus, ok := status.FromError(err)
|
errStatus, ok := status.FromError(err)
|
||||||
httpStatus := http.StatusInternalServerError
|
httpStatus := http.StatusInternalServerError
|
||||||
msg := "internal server error"
|
msg := "internal server error"
|
||||||
|
|||||||
@@ -513,7 +513,9 @@ func buildUserExportRequest() (string, error) {
|
|||||||
return string(str), nil
|
return string(str), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
|
func (am *Auth0Manager) createRequest(
|
||||||
|
method string, endpoint string, body io.Reader,
|
||||||
|
) (*http.Request, error) {
|
||||||
jwtToken, err := am.credentials.Authenticate()
|
jwtToken, err := am.credentials.Authenticate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -521,17 +523,23 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*
|
|||||||
|
|
||||||
reqURL := am.authIssuer + endpoint
|
reqURL := am.authIssuer + endpoint
|
||||||
|
|
||||||
payload := strings.NewReader(payloadStr)
|
req, err := http.NewRequest(method, reqURL, body)
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", reqURL, payload)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
|
||||||
|
req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
req.Header.Add("content-type", "application/json")
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
return req, nil
|
return req, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||||
@@ -737,6 +745,38 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteUser from Auth0
|
||||||
|
func (am *Auth0Manager) DeleteUser(userID string) error {
|
||||||
|
req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := am.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("execute delete request: %v", err)
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("close delete request body: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if resp.StatusCode != 204 {
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||||
// If the status is "completed", then return the downloadLink
|
// If the status is "completed", then return the downloadLink
|
||||||
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
||||||
|
|||||||
@@ -12,9 +12,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/api/v3"
|
"goauthentik.io/api/v3"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthentikManager authentik manager client instance.
|
// AuthentikManager authentik manager client instance.
|
||||||
@@ -453,6 +454,38 @@ func (am *AuthentikManager) InviteUserByID(_ string) error {
|
|||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteUser from Authentik
|
||||||
|
func (am *AuthentikManager) DeleteUser(userID string) error {
|
||||||
|
ctx, err := am.authenticationContext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
userPk, err := strconv.ParseInt(userID, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close() // nolint
|
||||||
|
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountDeleteUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNoContent {
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
||||||
jwtToken, err := am.credentials.Authenticate()
|
jwtToken, err := am.credentials.Authenticate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -454,6 +454,43 @@ func (am *AzureManager) InviteUserByID(_ string) error {
|
|||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteUser from Azure
|
||||||
|
func (am *AzureManager) DeleteUser(userID string) error {
|
||||||
|
jwtToken, err := am.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
|
||||||
|
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
log.Debugf("delete idp user %s", userID)
|
||||||
|
|
||||||
|
resp, err := am.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountDeleteUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNoContent {
|
||||||
|
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Add("$select", extensionFields)
|
q.Add("$select", extensionFields)
|
||||||
|
|||||||
@@ -254,6 +254,19 @@ func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
|
|||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteUser from GoogleWorkspace.
|
||||||
|
func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
|
||||||
|
if err := gm.usersService.Delete(userID).Do(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if gm.appMetrics != nil {
|
||||||
|
gm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
||||||
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
||||||
// If that fails, it falls back to using the default Google credentials path.
|
// If that fails, it falls back to using the default Google credentials path.
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type Manager interface {
|
|||||||
CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error)
|
CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||||
GetUserByEmail(email string) ([]*UserData, error)
|
GetUserByEmail(email string) ([]*UserData, error)
|
||||||
InviteUserByID(userID string) error
|
InviteUserByID(userID string) error
|
||||||
|
DeleteUser(userID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientConfig defines common client configuration for all IdP manager
|
// ClientConfig defines common client configuration for all IdP manager
|
||||||
@@ -37,10 +38,10 @@ type Config struct {
|
|||||||
ManagerType string
|
ManagerType string
|
||||||
ClientConfig *ClientConfig
|
ClientConfig *ClientConfig
|
||||||
ExtraConfig ExtraConfig
|
ExtraConfig ExtraConfig
|
||||||
Auth0ClientCredentials Auth0ClientConfig
|
Auth0ClientCredentials *Auth0ClientConfig
|
||||||
AzureClientCredentials AzureClientConfig
|
AzureClientCredentials *AzureClientConfig
|
||||||
KeycloakClientCredentials KeycloakClientConfig
|
KeycloakClientCredentials *KeycloakClientConfig
|
||||||
ZitadelClientCredentials ZitadelClientConfig
|
ZitadelClientCredentials *ZitadelClientConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||||
@@ -96,7 +97,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
|||||||
case "auth0":
|
case "auth0":
|
||||||
auth0ClientConfig := config.Auth0ClientCredentials
|
auth0ClientConfig := config.Auth0ClientCredentials
|
||||||
if config.ClientConfig != nil {
|
if config.ClientConfig != nil {
|
||||||
auth0ClientConfig = Auth0ClientConfig{
|
auth0ClientConfig = &Auth0ClientConfig{
|
||||||
Audience: config.ExtraConfig["Audience"],
|
Audience: config.ExtraConfig["Audience"],
|
||||||
AuthIssuer: config.ClientConfig.Issuer,
|
AuthIssuer: config.ClientConfig.Issuer,
|
||||||
ClientID: config.ClientConfig.ClientID,
|
ClientID: config.ClientConfig.ClientID,
|
||||||
@@ -105,11 +106,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewAuth0Manager(auth0ClientConfig, appMetrics)
|
return NewAuth0Manager(*auth0ClientConfig, appMetrics)
|
||||||
case "azure":
|
case "azure":
|
||||||
azureClientConfig := config.AzureClientCredentials
|
azureClientConfig := config.AzureClientCredentials
|
||||||
if config.ClientConfig != nil {
|
if config.ClientConfig != nil {
|
||||||
azureClientConfig = AzureClientConfig{
|
azureClientConfig = &AzureClientConfig{
|
||||||
ClientID: config.ClientConfig.ClientID,
|
ClientID: config.ClientConfig.ClientID,
|
||||||
ClientSecret: config.ClientConfig.ClientSecret,
|
ClientSecret: config.ClientConfig.ClientSecret,
|
||||||
GrantType: config.ClientConfig.GrantType,
|
GrantType: config.ClientConfig.GrantType,
|
||||||
@@ -119,11 +120,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewAzureManager(azureClientConfig, appMetrics)
|
return NewAzureManager(*azureClientConfig, appMetrics)
|
||||||
case "keycloak":
|
case "keycloak":
|
||||||
keycloakClientConfig := config.KeycloakClientCredentials
|
keycloakClientConfig := config.KeycloakClientCredentials
|
||||||
if config.ClientConfig != nil {
|
if config.ClientConfig != nil {
|
||||||
keycloakClientConfig = KeycloakClientConfig{
|
keycloakClientConfig = &KeycloakClientConfig{
|
||||||
ClientID: config.ClientConfig.ClientID,
|
ClientID: config.ClientConfig.ClientID,
|
||||||
ClientSecret: config.ClientConfig.ClientSecret,
|
ClientSecret: config.ClientConfig.ClientSecret,
|
||||||
GrantType: config.ClientConfig.GrantType,
|
GrantType: config.ClientConfig.GrantType,
|
||||||
@@ -132,11 +133,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewKeycloakManager(keycloakClientConfig, appMetrics)
|
return NewKeycloakManager(*keycloakClientConfig, appMetrics)
|
||||||
case "zitadel":
|
case "zitadel":
|
||||||
zitadelClientConfig := config.ZitadelClientCredentials
|
zitadelClientConfig := config.ZitadelClientCredentials
|
||||||
if config.ClientConfig != nil {
|
if config.ClientConfig != nil {
|
||||||
zitadelClientConfig = ZitadelClientConfig{
|
zitadelClientConfig = &ZitadelClientConfig{
|
||||||
ClientID: config.ClientConfig.ClientID,
|
ClientID: config.ClientConfig.ClientID,
|
||||||
ClientSecret: config.ClientConfig.ClientSecret,
|
ClientSecret: config.ClientConfig.ClientSecret,
|
||||||
GrantType: config.ClientConfig.GrantType,
|
GrantType: config.ClientConfig.GrantType,
|
||||||
@@ -145,7 +146,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewZitadelManager(zitadelClientConfig, appMetrics)
|
return NewZitadelManager(*zitadelClientConfig, appMetrics)
|
||||||
case "authentik":
|
case "authentik":
|
||||||
authentikConfig := AuthentikClientConfig{
|
authentikConfig := AuthentikClientConfig{
|
||||||
Issuer: config.ClientConfig.Issuer,
|
Issuer: config.ClientConfig.Issuer,
|
||||||
|
|||||||
@@ -467,6 +467,47 @@ func (km *KeycloakManager) InviteUserByID(_ string) error {
|
|||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteUser from Keycloack
|
||||||
|
func (km *KeycloakManager) DeleteUser(userID string) error {
|
||||||
|
jwtToken, err := km.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountDeleteUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := km.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close() // nolint
|
||||||
|
|
||||||
|
// In the docs, they specified 200, but in the endpoints, they return 204
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
||||||
attrs := keycloakUserAttributes{}
|
attrs := keycloakUserAttributes{}
|
||||||
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||||
|
|||||||
@@ -319,6 +319,28 @@ func (om *OktaManager) InviteUserByID(_ string) error {
|
|||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteUser from Okta
|
||||||
|
func (om *OktaManager) DeleteUser(userID string) error {
|
||||||
|
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if om.appMetrics != nil {
|
||||||
|
om.appMetrics.IDPMetrics().CountDeleteUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
if om.appMetrics != nil {
|
||||||
|
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// updateUserProfileSchema updates the Okta user schema to include custom fields,
|
// updateUserProfileSchema updates the Okta user schema to include custom fields,
|
||||||
// wt_account_id and wt_pending_invite.
|
// wt_account_id and wt_pending_invite.
|
||||||
func updateUserProfileSchema(client *okta.Client) error {
|
func updateUserProfileSchema(client *okta.Client) error {
|
||||||
|
|||||||
@@ -13,8 +13,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ZitadelManager zitadel manager client instance.
|
// ZitadelManager zitadel manager client instance.
|
||||||
@@ -447,6 +448,21 @@ func (zm *ZitadelManager) InviteUserByID(_ string) error {
|
|||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteUser from Zitadel
|
||||||
|
func (zm *ZitadelManager) DeleteUser(userID string) error {
|
||||||
|
resource := fmt.Sprintf("users/%s", userID)
|
||||||
|
if err := zm.delete(resource); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// getUserMetadata requests user metadata from zitadel via ID.
|
// getUserMetadata requests user metadata from zitadel via ID.
|
||||||
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
|
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
|
||||||
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
|
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
|
||||||
@@ -500,6 +516,42 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
|
|||||||
return io.ReadAll(resp.Body)
|
return io.ReadAll(resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete perform Delete requests.
|
||||||
|
func (zm *ZitadelManager) delete(resource string) error {
|
||||||
|
jwtToken, err := zm.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
|
||||||
|
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
resp, err := zm.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("unable to delete %s, statusCode %d", reqURL, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// get perform Get requests.
|
// get perform Get requests.
|
||||||
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
|
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
|
||||||
jwtToken, err := zm.credentials.Authenticate()
|
jwtToken, err := zm.credentials.Authenticate()
|
||||||
|
|||||||
@@ -412,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
|||||||
peersUpdateManager := NewPeersUpdateManager()
|
peersUpdateManager := NewPeersUpdateManager()
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "",
|
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "",
|
||||||
eventStore)
|
eventStore, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -503,7 +503,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
|||||||
peersUpdateManager := server.NewPeersUpdateManager()
|
peersUpdateManager := server.NewPeersUpdateManager()
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
|
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||||
eventStore)
|
eventStore, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed creating a manager: %v", err)
|
log.Fatalf("failed creating a manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ type MockAccountManager struct {
|
|||||||
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error)
|
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error)
|
||||||
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
|
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
|
||||||
SaveGroupFunc func(accountID, userID string, group *server.Group) error
|
SaveGroupFunc func(accountID, userID string, group *server.Group) error
|
||||||
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
|
|
||||||
DeleteGroupFunc func(accountID, userId, groupID string) error
|
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||||
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
||||||
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
||||||
@@ -54,7 +53,6 @@ type MockAccountManager struct {
|
|||||||
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||||
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
|
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
|
||||||
SaveRouteFunc func(accountID, userID string, route *route.Route) error
|
SaveRouteFunc func(accountID, userID string, route *route.Route) error
|
||||||
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
|
|
||||||
DeleteRouteFunc func(accountID, routeID, userID string) error
|
DeleteRouteFunc func(accountID, routeID, userID string) error
|
||||||
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
||||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||||
@@ -68,7 +66,6 @@ type MockAccountManager struct {
|
|||||||
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
||||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||||
UpdateNameServerGroupFunc func(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
|
|
||||||
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
||||||
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
||||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||||
@@ -267,14 +264,6 @@ func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server.
|
|||||||
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
|
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateGroup mock implementation of UpdateGroup from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) UpdateGroup(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
|
|
||||||
if am.UpdateGroupFunc != nil {
|
|
||||||
return am.UpdateGroupFunc(accountID, groupID, operations)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
|
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||||
if am.DeleteGroupFunc != nil {
|
if am.DeleteGroupFunc != nil {
|
||||||
@@ -435,14 +424,6 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R
|
|||||||
return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented")
|
return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoute mock implementation of UpdateRoute from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) UpdateRoute(accountID, ruleID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
|
|
||||||
if am.UpdateRouteFunc != nil {
|
|
||||||
return am.UpdateRouteFunc(accountID, ruleID, operations)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateRoute not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
|
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
||||||
if am.DeleteRouteFunc != nil {
|
if am.DeleteRouteFunc != nil {
|
||||||
@@ -533,14 +514,6 @@ func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGr
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateNameServerGroup mocks UpdateNameServerGroup of the AccountManager interface
|
|
||||||
func (am *MockAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
|
|
||||||
if am.UpdateNameServerGroupFunc != nil {
|
|
||||||
return am.UpdateNameServerGroupFunc(accountID, nsGroupID, userID, operations)
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface
|
// DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
||||||
if am.DeleteNameServerGroupFunc != nil {
|
if am.DeleteNameServerGroupFunc != nil {
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package server
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -15,54 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||||
// UpdateNameServerGroupName indicates a nameserver group name update operation
|
|
||||||
UpdateNameServerGroupName NameServerGroupUpdateOperationType = iota
|
|
||||||
// UpdateNameServerGroupDescription indicates a nameserver group description update operation
|
|
||||||
UpdateNameServerGroupDescription
|
|
||||||
// UpdateNameServerGroupNameServers indicates a nameserver group nameservers list update operation
|
|
||||||
UpdateNameServerGroupNameServers
|
|
||||||
// UpdateNameServerGroupGroups indicates a nameserver group' groups update operation
|
|
||||||
UpdateNameServerGroupGroups
|
|
||||||
// UpdateNameServerGroupEnabled indicates a nameserver group status update operation
|
|
||||||
UpdateNameServerGroupEnabled
|
|
||||||
// UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation
|
|
||||||
UpdateNameServerGroupPrimary
|
|
||||||
// UpdateNameServerGroupDomains indicates a nameserver group' domains update operation
|
|
||||||
UpdateNameServerGroupDomains
|
|
||||||
|
|
||||||
domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
|
||||||
)
|
|
||||||
|
|
||||||
// NameServerGroupUpdateOperationType operation type
|
|
||||||
type NameServerGroupUpdateOperationType int
|
|
||||||
|
|
||||||
func (t NameServerGroupUpdateOperationType) String() string {
|
|
||||||
switch t {
|
|
||||||
case UpdateNameServerGroupDescription:
|
|
||||||
return "UpdateNameServerGroupDescription"
|
|
||||||
case UpdateNameServerGroupName:
|
|
||||||
return "UpdateNameServerGroupName"
|
|
||||||
case UpdateNameServerGroupNameServers:
|
|
||||||
return "UpdateNameServerGroupNameServers"
|
|
||||||
case UpdateNameServerGroupGroups:
|
|
||||||
return "UpdateNameServerGroupGroups"
|
|
||||||
case UpdateNameServerGroupEnabled:
|
|
||||||
return "UpdateNameServerGroupEnabled"
|
|
||||||
case UpdateNameServerGroupPrimary:
|
|
||||||
return "UpdateNameServerGroupPrimary"
|
|
||||||
case UpdateNameServerGroupDomains:
|
|
||||||
return "UpdateNameServerGroupDomains"
|
|
||||||
default:
|
|
||||||
return "InvalidOperation"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NameServerGroupUpdateOperation operation object with type and values to be applied
|
|
||||||
type NameServerGroupUpdateOperation struct {
|
|
||||||
Type NameServerGroupUpdateOperationType
|
|
||||||
Values []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||||
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||||
@@ -172,109 +124,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateNameServerGroup updates existing nameserver group with set of operations
|
|
||||||
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
|
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(operations) == 0 {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
|
||||||
}
|
|
||||||
|
|
||||||
newNSGroup := nsGroupToUpdate.Copy()
|
|
||||||
|
|
||||||
for _, operation := range operations {
|
|
||||||
valuesCount := len(operation.Values)
|
|
||||||
if valuesCount < 1 {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, value := range operation.Values {
|
|
||||||
if value == "" {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
switch operation.Type {
|
|
||||||
case UpdateNameServerGroupDescription:
|
|
||||||
newNSGroup.Description = operation.Values[0]
|
|
||||||
case UpdateNameServerGroupName:
|
|
||||||
if valuesCount > 1 {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
|
|
||||||
}
|
|
||||||
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newNSGroup.Name = operation.Values[0]
|
|
||||||
case UpdateNameServerGroupNameServers:
|
|
||||||
var nsList []nbdns.NameServer
|
|
||||||
for _, url := range operation.Values {
|
|
||||||
ns, err := nbdns.ParseNameServerURL(url)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
nsList = append(nsList, ns)
|
|
||||||
}
|
|
||||||
err = validateNSList(nsList)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newNSGroup.NameServers = nsList
|
|
||||||
case UpdateNameServerGroupGroups:
|
|
||||||
err = validateGroups(operation.Values, account.Groups)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newNSGroup.Groups = operation.Values
|
|
||||||
case UpdateNameServerGroupEnabled:
|
|
||||||
enabled, err := strconv.ParseBool(operation.Values[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
|
||||||
}
|
|
||||||
newNSGroup.Enabled = enabled
|
|
||||||
case UpdateNameServerGroupPrimary:
|
|
||||||
primary, err := strconv.ParseBool(operation.Values[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
|
|
||||||
}
|
|
||||||
newNSGroup.Primary = primary
|
|
||||||
case UpdateNameServerGroupDomains:
|
|
||||||
err = validateDomainInput(false, operation.Values)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newNSGroup.Domains = operation.Values
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
account.NameServerGroups[nsGroupID] = newNSGroup
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
err = am.Store.SaveAccount(account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
return newNSGroup.Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||||
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
||||||
|
|
||||||
|
|||||||
@@ -655,323 +655,6 @@ func TestSaveNameServerGroup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateNameServerGroup(t *testing.T) {
|
|
||||||
nsGroupID := "testingNSGroup"
|
|
||||||
|
|
||||||
existingNSGroup := &nbdns.NameServerGroup{
|
|
||||||
ID: nsGroupID,
|
|
||||||
Name: "super",
|
|
||||||
Description: "super",
|
|
||||||
Primary: true,
|
|
||||||
NameServers: []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("1.1.1.1"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: nbdns.DefaultDNSPort,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("1.1.2.2"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: nbdns.DefaultDNSPort,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Groups: []string{group1ID},
|
|
||||||
Enabled: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
existingNSGroup *nbdns.NameServerGroup
|
|
||||||
nsGroupID string
|
|
||||||
operations []NameServerGroupUpdateOperation
|
|
||||||
shouldCreate bool
|
|
||||||
errFunc require.ErrorAssertionFunc
|
|
||||||
expectedNSGroup *nbdns.NameServerGroup
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Should Config Single Property",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupName,
|
|
||||||
Values: []string{"superNew"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.NoError,
|
|
||||||
shouldCreate: true,
|
|
||||||
expectedNSGroup: &nbdns.NameServerGroup{
|
|
||||||
ID: nsGroupID,
|
|
||||||
Name: "superNew",
|
|
||||||
Description: "super",
|
|
||||||
Primary: true,
|
|
||||||
NameServers: []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("1.1.1.1"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: nbdns.DefaultDNSPort,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("1.1.2.2"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: nbdns.DefaultDNSPort,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Groups: []string{group1ID},
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Config Multiple Properties",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupName,
|
|
||||||
Values: []string{"superNew"},
|
|
||||||
},
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupDescription,
|
|
||||||
Values: []string{"superDescription"},
|
|
||||||
},
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupNameServers,
|
|
||||||
Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53"},
|
|
||||||
},
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupGroups,
|
|
||||||
Values: []string{group1ID, group2ID},
|
|
||||||
},
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupEnabled,
|
|
||||||
Values: []string{"false"},
|
|
||||||
},
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupPrimary,
|
|
||||||
Values: []string{"false"},
|
|
||||||
},
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupDomains,
|
|
||||||
Values: []string{validDomain},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.NoError,
|
|
||||||
shouldCreate: true,
|
|
||||||
expectedNSGroup: &nbdns.NameServerGroup{
|
|
||||||
ID: nsGroupID,
|
|
||||||
Name: "superNew",
|
|
||||||
Description: "superDescription",
|
|
||||||
Primary: false,
|
|
||||||
Domains: []string{validDomain},
|
|
||||||
NameServers: []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("127.0.0.1"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: nbdns.DefaultDNSPort,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: nbdns.DefaultDNSPort,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Groups: []string{group1ID, group2ID},
|
|
||||||
Enabled: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Invalid ID",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: "nonExistingNSGroup",
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Empty Operations",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Empty Values",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupName,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Empty String",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupName,
|
|
||||||
Values: []string{""},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Invalid Name Large String",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupName,
|
|
||||||
Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Invalid On Existing Name",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupName,
|
|
||||||
Values: []string{existingNSGroupName},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Invalid On Multiple Name Values",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupName,
|
|
||||||
Values: []string{"nameOne", "nameTwo"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Invalid Boolean",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupEnabled,
|
|
||||||
Values: []string{"yes"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Invalid Nameservers Wrong Schema",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupNameServers,
|
|
||||||
Values: []string{"https://127.0.0.1:53"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Invalid Nameservers Wrong IP",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupNameServers,
|
|
||||||
Values: []string{"udp://8.8.8.300:53"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Large Number Of Nameservers",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupNameServers,
|
|
||||||
Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53", "udp://8.8.4.4:53"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Invalid GroupID",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupGroups,
|
|
||||||
Values: []string{"nonExistingGroupID"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Invalid Domains",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupDomains,
|
|
||||||
Values: []string{invalidDomain},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Config On Invalid Primary Status",
|
|
||||||
existingNSGroup: existingNSGroup,
|
|
||||||
nsGroupID: existingNSGroup.ID,
|
|
||||||
operations: []NameServerGroupUpdateOperation{
|
|
||||||
NameServerGroupUpdateOperation{
|
|
||||||
Type: UpdateNameServerGroupPrimary,
|
|
||||||
Values: []string{"yes"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, testCase := range testCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
am, err := createNSManager(t)
|
|
||||||
if err != nil {
|
|
||||||
t.Error("failed to create account manager")
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := initTestNSAccount(t, am)
|
|
||||||
if err != nil {
|
|
||||||
t.Error("failed to init testing account")
|
|
||||||
}
|
|
||||||
|
|
||||||
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
|
|
||||||
|
|
||||||
err = am.Store.SaveAccount(account)
|
|
||||||
if err != nil {
|
|
||||||
t.Error("account should be saved")
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedRoute, err := am.UpdateNameServerGroup(account.Id, testCase.nsGroupID, userID, testCase.operations)
|
|
||||||
testCase.errFunc(t, err)
|
|
||||||
|
|
||||||
if !testCase.shouldCreate {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase.expectedNSGroup.ID = updatedRoute.ID
|
|
||||||
|
|
||||||
if !testCase.expectedNSGroup.IsEqual(updatedRoute) {
|
|
||||||
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedNSGroup)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeleteNameServerGroup(t *testing.T) {
|
func TestDeleteNameServerGroup(t *testing.T) {
|
||||||
nsGroupID := "testingNSGroup"
|
nsGroupID := "testingNSGroup"
|
||||||
|
|
||||||
@@ -1061,7 +744,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore)
|
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNSStore(t *testing.T) (Store, error) {
|
func createNSStore(t *testing.T) (Store, error) {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
@@ -13,57 +12,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
// UpdateRouteDescription indicates a route description update operation
|
|
||||||
UpdateRouteDescription RouteUpdateOperationType = iota
|
|
||||||
// UpdateRouteNetwork indicates a route IP update operation
|
|
||||||
UpdateRouteNetwork
|
|
||||||
// UpdateRoutePeer indicates a route peer update operation
|
|
||||||
UpdateRoutePeer
|
|
||||||
// UpdateRouteMetric indicates a route metric update operation
|
|
||||||
UpdateRouteMetric
|
|
||||||
// UpdateRouteMasquerade indicates a route masquerade update operation
|
|
||||||
UpdateRouteMasquerade
|
|
||||||
// UpdateRouteEnabled indicates a route enabled update operation
|
|
||||||
UpdateRouteEnabled
|
|
||||||
// UpdateRouteNetworkIdentifier indicates a route net ID update operation
|
|
||||||
UpdateRouteNetworkIdentifier
|
|
||||||
// UpdateRouteGroups indicates a group list update operation
|
|
||||||
UpdateRouteGroups
|
|
||||||
)
|
|
||||||
|
|
||||||
// RouteUpdateOperationType operation type
|
|
||||||
type RouteUpdateOperationType int
|
|
||||||
|
|
||||||
func (t RouteUpdateOperationType) String() string {
|
|
||||||
switch t {
|
|
||||||
case UpdateRouteDescription:
|
|
||||||
return "UpdateRouteDescription"
|
|
||||||
case UpdateRouteNetwork:
|
|
||||||
return "UpdateRouteNetwork"
|
|
||||||
case UpdateRoutePeer:
|
|
||||||
return "UpdateRoutePeer"
|
|
||||||
case UpdateRouteMetric:
|
|
||||||
return "UpdateRouteMetric"
|
|
||||||
case UpdateRouteMasquerade:
|
|
||||||
return "UpdateRouteMasquerade"
|
|
||||||
case UpdateRouteEnabled:
|
|
||||||
return "UpdateRouteEnabled"
|
|
||||||
case UpdateRouteNetworkIdentifier:
|
|
||||||
return "UpdateRouteNetworkIdentifier"
|
|
||||||
case UpdateRouteGroups:
|
|
||||||
return "UpdateRouteGroups"
|
|
||||||
default:
|
|
||||||
return "InvalidOperation"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RouteUpdateOperation operation object with type and values to be applied
|
|
||||||
type RouteUpdateOperation struct {
|
|
||||||
Type RouteUpdateOperationType
|
|
||||||
Values []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRoute gets a route object from account and route IDs
|
// GetRoute gets a route object from account and route IDs
|
||||||
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
|
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
@@ -241,109 +189,6 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoute updates existing route with set of operations
|
|
||||||
func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) {
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
routeToUpdate, ok := account.Routes[routeID]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute := routeToUpdate.Copy()
|
|
||||||
|
|
||||||
for _, operation := range operations {
|
|
||||||
|
|
||||||
if len(operation.Values) != 1 {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
switch operation.Type {
|
|
||||||
case UpdateRouteDescription:
|
|
||||||
newRoute.Description = operation.Values[0]
|
|
||||||
case UpdateRouteNetworkIdentifier:
|
|
||||||
if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
|
||||||
}
|
|
||||||
newRoute.NetID = operation.Values[0]
|
|
||||||
case UpdateRouteNetwork:
|
|
||||||
prefixType, prefix, err := route.ParseNetwork(operation.Values[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0])
|
|
||||||
}
|
|
||||||
err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newRoute.Network = prefix
|
|
||||||
newRoute.NetworkType = prefixType
|
|
||||||
case UpdateRoutePeer:
|
|
||||||
if operation.Values[0] != "" {
|
|
||||||
peer := account.GetPeer(operation.Values[0])
|
|
||||||
if peer == nil {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", operation.Values[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.checkPrefixPeerExists(accountID, operation.Values[0], routeToUpdate.Network)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newRoute.Peer = operation.Values[0]
|
|
||||||
case UpdateRouteMetric:
|
|
||||||
metric, err := strconv.Atoi(operation.Values[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
|
|
||||||
}
|
|
||||||
if metric < route.MinMetric || metric > route.MaxMetric {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
|
|
||||||
operation.Values[0],
|
|
||||||
route.MinMetric,
|
|
||||||
route.MaxMetric,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
newRoute.Metric = metric
|
|
||||||
case UpdateRouteMasquerade:
|
|
||||||
masquerade, err := strconv.ParseBool(operation.Values[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
|
|
||||||
}
|
|
||||||
newRoute.Masquerade = masquerade
|
|
||||||
case UpdateRouteEnabled:
|
|
||||||
enabled, err := strconv.ParseBool(operation.Values[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
|
||||||
}
|
|
||||||
newRoute.Enabled = enabled
|
|
||||||
case UpdateRouteGroups:
|
|
||||||
err = validateGroups(operation.Values, account.Groups)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newRoute.Groups = operation.Values
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Routes[routeID] = newRoute
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(account); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(status.Internal, "failed to update account peers")
|
|
||||||
}
|
|
||||||
return newRoute, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRoute deletes route with routeID
|
// DeleteRoute deletes route with routeID
|
||||||
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
|||||||
@@ -524,265 +524,6 @@ func TestSaveRoute(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateRoute(t *testing.T) {
|
|
||||||
routeID := "testingRouteID"
|
|
||||||
|
|
||||||
existingRoute := &route.Route{
|
|
||||||
ID: routeID,
|
|
||||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
|
||||||
NetID: "superRoute",
|
|
||||||
NetworkType: route.IPv4Network,
|
|
||||||
Peer: peer1ID,
|
|
||||||
Description: "super",
|
|
||||||
Masquerade: false,
|
|
||||||
Metric: 9999,
|
|
||||||
Enabled: true,
|
|
||||||
Groups: []string{routeGroup1},
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
existingRoute *route.Route
|
|
||||||
operations []RouteUpdateOperation
|
|
||||||
shouldCreate bool
|
|
||||||
errFunc require.ErrorAssertionFunc
|
|
||||||
expectedRoute *route.Route
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Happy Path Single OPS",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRoutePeer,
|
|
||||||
Values: []string{peer2ID},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.NoError,
|
|
||||||
shouldCreate: true,
|
|
||||||
expectedRoute: &route.Route{
|
|
||||||
ID: routeID,
|
|
||||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
|
||||||
NetID: "superRoute",
|
|
||||||
NetworkType: route.IPv4Network,
|
|
||||||
Peer: peer2ID,
|
|
||||||
Description: "super",
|
|
||||||
Masquerade: false,
|
|
||||||
Metric: 9999,
|
|
||||||
Enabled: true,
|
|
||||||
Groups: []string{routeGroup1},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Happy Path Multiple OPS",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRouteDescription,
|
|
||||||
Values: []string{"great"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Type: UpdateRouteNetwork,
|
|
||||||
Values: []string{"192.168.0.0/24"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Type: UpdateRoutePeer,
|
|
||||||
Values: []string{peer2ID},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Type: UpdateRouteMetric,
|
|
||||||
Values: []string{"3030"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Type: UpdateRouteMasquerade,
|
|
||||||
Values: []string{"true"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Type: UpdateRouteEnabled,
|
|
||||||
Values: []string{"false"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Type: UpdateRouteNetworkIdentifier,
|
|
||||||
Values: []string{"megaRoute"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Type: UpdateRouteGroups,
|
|
||||||
Values: []string{routeGroup2},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.NoError,
|
|
||||||
shouldCreate: true,
|
|
||||||
expectedRoute: &route.Route{
|
|
||||||
ID: routeID,
|
|
||||||
Network: netip.MustParsePrefix("192.168.0.0/24"),
|
|
||||||
NetID: "megaRoute",
|
|
||||||
NetworkType: route.IPv4Network,
|
|
||||||
Peer: peer2ID,
|
|
||||||
Description: "great",
|
|
||||||
Masquerade: true,
|
|
||||||
Metric: 3030,
|
|
||||||
Enabled: false,
|
|
||||||
Groups: []string{routeGroup2},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty Values Should Fail",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRoutePeer,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Multiple Values Should Fail",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRoutePeer,
|
|
||||||
Values: []string{peer2ID, peer1ID},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Bad Prefix Should Fail",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRouteNetwork,
|
|
||||||
Values: []string{"192.168.0.0/34"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Bad Peer Should Fail",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRoutePeer,
|
|
||||||
Values: []string{"non existing Peer"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty Peer",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRoutePeer,
|
|
||||||
Values: []string{""},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.NoError,
|
|
||||||
shouldCreate: true,
|
|
||||||
expectedRoute: &route.Route{
|
|
||||||
ID: routeID,
|
|
||||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
|
||||||
NetID: "superRoute",
|
|
||||||
NetworkType: route.IPv4Network,
|
|
||||||
Peer: "",
|
|
||||||
Description: "super",
|
|
||||||
Masquerade: false,
|
|
||||||
Metric: 9999,
|
|
||||||
Enabled: true,
|
|
||||||
Groups: []string{routeGroup1},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Large Network ID Should Fail",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRouteNetworkIdentifier,
|
|
||||||
Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty Network ID Should Fail",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRouteNetworkIdentifier,
|
|
||||||
Values: []string{""},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid Metric Should Fail",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRouteMetric,
|
|
||||||
Values: []string{"999999"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid Boolean Should Fail",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRouteMasquerade,
|
|
||||||
Values: []string{"yes"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid Group Should Fail",
|
|
||||||
existingRoute: existingRoute,
|
|
||||||
operations: []RouteUpdateOperation{
|
|
||||||
{
|
|
||||||
Type: UpdateRouteGroups,
|
|
||||||
Values: []string{routeInvalidGroup1},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
errFunc: require.Error,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, testCase := range testCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
am, err := createRouterManager(t)
|
|
||||||
if err != nil {
|
|
||||||
t.Error("failed to create account manager")
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := initTestRouteAccount(t, am)
|
|
||||||
if err != nil {
|
|
||||||
t.Error("failed to init testing account")
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute
|
|
||||||
|
|
||||||
err = am.Store.SaveAccount(account)
|
|
||||||
if err != nil {
|
|
||||||
t.Error("account should be saved")
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedRoute, err := am.UpdateRoute(account.Id, testCase.existingRoute.ID, testCase.operations)
|
|
||||||
|
|
||||||
testCase.errFunc(t, err)
|
|
||||||
|
|
||||||
if !testCase.shouldCreate {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase.expectedRoute.ID = updatedRoute.ID
|
|
||||||
|
|
||||||
if !testCase.expectedRoute.IsEqual(updatedRoute) {
|
|
||||||
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedRoute)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeleteRoute(t *testing.T) {
|
func TestDeleteRoute(t *testing.T) {
|
||||||
testingRoute := &route.Route{
|
testingRoute := &route.Route{
|
||||||
ID: "testingRoute",
|
ID: "testingRoute",
|
||||||
@@ -940,7 +681,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore)
|
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRouterStore(t *testing.T) (Store, error) {
|
func createRouterStore(t *testing.T) (Store, error) {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package telemetry
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
"go.opentelemetry.io/otel/metric/instrument"
|
"go.opentelemetry.io/otel/metric/instrument"
|
||||||
"go.opentelemetry.io/otel/metric/instrument/syncint64"
|
"go.opentelemetry.io/otel/metric/instrument/syncint64"
|
||||||
@@ -13,6 +14,7 @@ type IDPMetrics struct {
|
|||||||
getUserByEmailCounter syncint64.Counter
|
getUserByEmailCounter syncint64.Counter
|
||||||
getAllAccountsCounter syncint64.Counter
|
getAllAccountsCounter syncint64.Counter
|
||||||
createUserCounter syncint64.Counter
|
createUserCounter syncint64.Counter
|
||||||
|
deleteUserCounter syncint64.Counter
|
||||||
getAccountCounter syncint64.Counter
|
getAccountCounter syncint64.Counter
|
||||||
getUserByIDCounter syncint64.Counter
|
getUserByIDCounter syncint64.Counter
|
||||||
authenticateRequestCounter syncint64.Counter
|
authenticateRequestCounter syncint64.Counter
|
||||||
@@ -39,6 +41,10 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
deleteUserCounter, err := meter.SyncInt64().Counter("management.idp.delete.user.counter", instrument.WithUnit("1"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1"))
|
getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -65,6 +71,7 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error)
|
|||||||
getUserByEmailCounter: getUserByEmailCounter,
|
getUserByEmailCounter: getUserByEmailCounter,
|
||||||
getAllAccountsCounter: getAllAccountsCounter,
|
getAllAccountsCounter: getAllAccountsCounter,
|
||||||
createUserCounter: createUserCounter,
|
createUserCounter: createUserCounter,
|
||||||
|
deleteUserCounter: deleteUserCounter,
|
||||||
getAccountCounter: getAccountCounter,
|
getAccountCounter: getAccountCounter,
|
||||||
getUserByIDCounter: getUserByIDCounter,
|
getUserByIDCounter: getUserByIDCounter,
|
||||||
authenticateRequestCounter: authenticateRequestCounter,
|
authenticateRequestCounter: authenticateRequestCounter,
|
||||||
@@ -88,6 +95,11 @@ func (idpMetrics *IDPMetrics) CountCreateUser() {
|
|||||||
idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1)
|
idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountDeleteUser ...
|
||||||
|
func (idpMetrics *IDPMetrics) CountDeleteUser() {
|
||||||
|
idpMetrics.deleteUserCounter.Add(idpMetrics.ctx, 1)
|
||||||
|
}
|
||||||
|
|
||||||
// CountGetAllAccounts ...
|
// CountGetAllAccounts ...
|
||||||
func (idpMetrics *IDPMetrics) CountGetAllAccounts() {
|
func (idpMetrics *IDPMetrics) CountGetAllAccounts() {
|
||||||
idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1)
|
idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1)
|
||||||
|
|||||||
@@ -309,6 +309,9 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
|
|||||||
|
|
||||||
// DeleteUser deletes a user from the given account.
|
// DeleteUser deletes a user from the given account.
|
||||||
func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
|
func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
|
||||||
|
if initiatorUserID == targetUserID {
|
||||||
|
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
|
||||||
|
}
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -327,15 +330,43 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
|
|||||||
return status.Errorf(status.NotFound, "user not found")
|
return status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
if executingUser.Role != UserRoleAdmin {
|
if executingUser.Role != UserRoleAdmin {
|
||||||
return status.Errorf(status.PermissionDenied, "only admins can delete service users")
|
return status.Errorf(status.PermissionDenied, "only admins can delete users")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !targetUser.IsServiceUser {
|
peers, err := account.FindUserPeers(targetUserID)
|
||||||
return status.Errorf(status.PermissionDenied, "regular users can not be deleted")
|
if err != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to find user peers")
|
||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": targetUser.ServiceUserName}
|
if err := am.expireAndUpdatePeers(account, peers); err != nil {
|
||||||
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
log.Errorf("failed update deleted peers expiration: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(account.Id, initiatorUserID, targetUserID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to resolve email address: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var meta map[string]any
|
||||||
|
var eventAction activity.Activity
|
||||||
|
if targetUser.IsServiceUser {
|
||||||
|
meta = map[string]any{"name": targetUser.ServiceUserName}
|
||||||
|
eventAction = activity.ServiceUserDeleted
|
||||||
|
} else {
|
||||||
|
meta = map[string]any{"name": tuName, "email": tuEmail}
|
||||||
|
eventAction = activity.UserDeleted
|
||||||
|
}
|
||||||
|
am.storeEvent(initiatorUserID, targetUserID, accountID, eventAction, meta)
|
||||||
|
|
||||||
|
if !targetUser.IsServiceUser && !isNil(am.idpManager) {
|
||||||
|
err := am.deleteUserFromIDP(targetUserID, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to delete user from IDP: %s", targetUserID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
delete(account.Users, targetUserID)
|
delete(account.Users, targetUserID)
|
||||||
|
|
||||||
@@ -609,23 +640,10 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var peerIDs []string
|
|
||||||
for _, peer := range blockedPeers {
|
|
||||||
peerIDs = append(peerIDs, peer.ID)
|
|
||||||
peer.MarkLoginExpired(true)
|
|
||||||
account.UpdatePeer(peer)
|
|
||||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
am.peersUpdateManager.CloseChannels(peerIDs)
|
|
||||||
err = am.updateAccountPeers(account)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
|
|
||||||
return nil, err
|
|
||||||
|
|
||||||
|
if err := am.expireAndUpdatePeers(account, blockedPeers); err != nil {
|
||||||
|
log.Errorf("failed update expired peers: %s", err)
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -814,6 +832,67 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
|||||||
return userInfos, nil
|
return userInfos, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
|
||||||
|
func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*Peer) error {
|
||||||
|
var peerIDs []string
|
||||||
|
for _, peer := range peers {
|
||||||
|
peerIDs = append(peerIDs, peer.ID)
|
||||||
|
peer.MarkLoginExpired(true)
|
||||||
|
account.UpdatePeer(peer)
|
||||||
|
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
am.storeEvent(
|
||||||
|
peer.UserID, peer.ID, account.Id,
|
||||||
|
activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(peerIDs) != 0 {
|
||||||
|
// this will trigger peer disconnect from the management service
|
||||||
|
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||||
|
if err := am.updateAccountPeers(account); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID string) error {
|
||||||
|
if am.userDeleteFromIDPEnabled {
|
||||||
|
log.Debugf("user %s deleted from IdP", targetUserID)
|
||||||
|
err := am.idpManager.DeleteUser(targetUserID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := am.idpManager.UpdateUserAppMetadata(targetUserID, idp.AppMetadata{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = am.refreshCache(accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("refresh account (%q) cache: %v", accountID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(accountId, initiatorId, targetId string) (string, string, error) {
|
||||||
|
userInfos, err := am.GetUsersFromAccount(accountId, initiatorId)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
for _, ui := range userInfos {
|
||||||
|
if ui.ID == targetId {
|
||||||
|
return ui.Email, ui.Name, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", "", fmt.Errorf("user info not found for user: %s", targetId)
|
||||||
|
}
|
||||||
|
|
||||||
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
|
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
|
||||||
for _, user := range userData {
|
for _, user := range userData {
|
||||||
if user.ID == userID {
|
if user.ID == userID {
|
||||||
|
|||||||
@@ -424,7 +424,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
|||||||
assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
|
assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_DeleteUser_regularUser(t *testing.T) {
|
func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||||
|
|
||||||
@@ -439,8 +439,35 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = am.DeleteUser(mockAccountID, mockUserID, mockUserID)
|
err = am.DeleteUser(mockAccountID, mockUserID, mockUserID)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("failed to prevent self deletion")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
|
func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||||
|
targetId := "user2"
|
||||||
|
account.Users[targetId] = &User{
|
||||||
|
Id: targetId,
|
||||||
|
IsServiceUser: true,
|
||||||
|
ServiceUserName: "user2username",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error when saving account: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
am := DefaultAccountManager{
|
||||||
|
Store: store,
|
||||||
|
eventStore: &activity.InMemoryEventStore{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.DeleteUser(mockAccountID, mockUserID, targetId)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||||
|
|||||||
64
util/file.go
64
util/file.go
@@ -5,6 +5,8 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WriteJson writes JSON config object to a file creating parent directories if required
|
// WriteJson writes JSON config object to a file creating parent directories if required
|
||||||
@@ -54,6 +56,68 @@ func WriteJson(file string, obj interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
|
||||||
|
func DirectWriteJson(file string, obj interface{}) error {
|
||||||
|
|
||||||
|
_, _, err := prepareConfigFileDir(file)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
targetFile, err := openOrCreateFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err = targetFile.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to close file %s: %v", file, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// make it pretty
|
||||||
|
bs, err := json.MarshalIndent(obj, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = targetFile.Truncate(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = targetFile.Write(bs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openOrCreateFile(file string) (*os.File, error) {
|
||||||
|
s, err := os.Stat(file)
|
||||||
|
if err == nil {
|
||||||
|
return os.OpenFile(file, os.O_WRONLY, s.Mode())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
targetFile, err := os.Create(file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
//no:lint
|
||||||
|
err = targetFile.Chmod(0640)
|
||||||
|
if err != nil {
|
||||||
|
_ = targetFile.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return targetFile, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ReadJson reads JSON config file and maps to a provided interface
|
// ReadJson reads JSON config file and maps to a provided interface
|
||||||
func ReadJson(file string, res interface{}) (interface{}, error) {
|
func ReadJson(file string, res interface{}) (interface{}, error) {
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user