diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index ab23f178e..c4bd3140b 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -2,6 +2,10 @@
## Issue ticket number and link
+## Stack
+
+
+
### Checklist
- [ ] Is it a bug fix
- [ ] Is a typo/documentation fix
diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml
index 5a3c6c22e..174b7d205 100644
--- a/.github/workflows/test-infrastructure-files.yml
+++ b/.github/workflows/test-infrastructure-files.yml
@@ -178,6 +178,7 @@ jobs:
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
+ grep DisablePromptLogin management.json | grep 'true'
- name: Install modules
run: go mod tidy
diff --git a/README.md b/README.md
index e39382acd..4ab9db03b 100644
--- a/README.md
+++ b/README.md
@@ -57,16 +57,16 @@
### Key features
-| Connectivity | Management | Security | Automation | Platforms |
-|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
-|
| - - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)
| - - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)
| - - \[x] [Public API](https://docs.netbird.io/api)
| |
-| - - \[x] Peer-to-peer connections
| - - \[x] Auto peer discovery and configuration
| - - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)
| - - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)
| |
-| - - \[x] Connection relay fallback
| - - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)
| - - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity)
| - - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)
| |
-| - - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)
| - - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)
| - - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)
| - - \[x] IdP groups sync with JWT
| |
-| - - \[x] NAT traversal with BPF
| - - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)
| - - \[x] Peer-to-peer encryption
| | |
-| | | - - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)
| | |
-| | | - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | - - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)
|
-| | | | | |
+| Connectivity | Management | Security | Automation| Platforms |
+|----|----|----|----|----|
+| | - - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)
| - - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)
| - - \[x] [Public API](https://docs.netbird.io/api)
| |
+| - - \[x] Peer-to-peer connections
| - - \[x] Auto peer discovery and configuration
| - - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)
| - - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)
| - - \[x] Mac
|
+| - - \[x] Connection relay fallback
| - - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)
| - - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity)
| - - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)
| - - \[x] Windows
|
+| - - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)
| - - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)
| - - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)
| - - \[x] IdP groups sync with JWT
| - - \[x] Android
|
+| - - \[x] NAT traversal with BPF
| - - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)
| - - \[x] Peer-to-peer encryption
|| - - \[x] iOS
|
+||| - - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)
|| - - \[x] OpenWRT
|
+||| - - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)
|| - - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)
|
+||||| - - \[x] Docker
|
### Quickstart with NetBird Cloud
diff --git a/client/cmd/debug.go b/client/cmd/debug.go
index c02f60aed..d2e5bdd7e 100644
--- a/client/cmd/debug.go
+++ b/client/cmd/debug.go
@@ -11,9 +11,12 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/internal/debug"
+ "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
+ mgmProto "github.com/netbirdio/netbird/management/proto"
)
const errCloseConnection = "Failed to close connection: %v"
@@ -326,3 +329,34 @@ func formatDuration(d time.Duration) string {
s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}
+
+func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
+ var networkMap *mgmProto.NetworkMap
+ var err error
+
+ if connectClient != nil {
+ networkMap, err = connectClient.GetLatestNetworkMap()
+ if err != nil {
+ log.Warnf("Failed to get latest network map: %v", err)
+ }
+ }
+
+ bundleGenerator := debug.NewBundleGenerator(
+ debug.GeneratorDependencies{
+ InternalConfig: config,
+ StatusRecorder: recorder,
+ NetworkMap: networkMap,
+ LogFile: logFilePath,
+ },
+ debug.BundleConfig{
+ IncludeSystemInfo: true,
+ },
+ )
+
+ path, err := bundleGenerator.Generate()
+ if err != nil {
+ log.Errorf("Failed to generate debug bundle: %v", err)
+ return
+ }
+ log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
+}
diff --git a/client/cmd/debug_unix.go b/client/cmd/debug_unix.go
new file mode 100644
index 000000000..45ace7e13
--- /dev/null
+++ b/client/cmd/debug_unix.go
@@ -0,0 +1,39 @@
+//go:build unix
+
+package cmd
+
+import (
+ "context"
+ "os"
+ "os/signal"
+ "syscall"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/internal/peer"
+)
+
+func SetupDebugHandler(
+ ctx context.Context,
+ config *internal.Config,
+ recorder *peer.Status,
+ connectClient *internal.ConnectClient,
+ logFilePath string,
+) {
+ usr1Ch := make(chan os.Signal, 1)
+
+ signal.Notify(usr1Ch, syscall.SIGUSR1)
+
+ go func() {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-usr1Ch:
+ log.Info("Received SIGUSR1. Triggering debug bundle generation.")
+ go generateDebugBundle(config, recorder, connectClient, logFilePath)
+ }
+ }
+ }()
+}
diff --git a/client/cmd/debug_windows.go b/client/cmd/debug_windows.go
new file mode 100644
index 000000000..f57955fd4
--- /dev/null
+++ b/client/cmd/debug_windows.go
@@ -0,0 +1,126 @@
+package cmd
+
+import (
+ "context"
+ "errors"
+ "os"
+ "strconv"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+
+ "github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/internal/peer"
+)
+
+const (
+ envListenEvent = "NB_LISTEN_DEBUG_EVENT"
+ debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
+
+ waitTimeout = 5 * time.Second
+)
+
+// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
+// Example usage with PowerShell:
+// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
+// $evt.Set()
+// $evt.Close()
+func SetupDebugHandler(
+ ctx context.Context,
+ config *internal.Config,
+ recorder *peer.Status,
+ connectClient *internal.ConnectClient,
+ logFilePath string,
+) {
+ env := os.Getenv(envListenEvent)
+ if env == "" {
+ return
+ }
+
+ listenEvent, err := strconv.ParseBool(env)
+ if err != nil {
+ log.Errorf("Failed to parse %s: %v", envListenEvent, err)
+ return
+ }
+ if !listenEvent {
+ return
+ }
+
+ eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
+ if err != nil {
+ log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
+ return
+ }
+
+ // TODO: restrict access by ACL
+ eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
+ if err != nil {
+ if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
+ log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
+ // SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
+ eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
+ if err != nil {
+ log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
+ return
+ }
+ log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
+ } else {
+ log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
+ return
+ }
+ }
+
+ if eventHandle == windows.InvalidHandle {
+ log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
+ return
+ }
+
+ log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
+
+ go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
+}
+
+func waitForEvent(
+ ctx context.Context,
+ config *internal.Config,
+ recorder *peer.Status,
+ connectClient *internal.ConnectClient,
+ logFilePath string,
+ eventHandle windows.Handle,
+) {
+ defer func() {
+ if err := windows.CloseHandle(eventHandle); err != nil {
+ log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
+ }
+ }()
+
+ for {
+ if ctx.Err() != nil {
+ return
+ }
+
+ status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
+
+ switch status {
+ case windows.WAIT_OBJECT_0:
+ log.Info("Received signal on debug event. Triggering debug bundle generation.")
+
+ // reset the event so it can be triggered again later (manual reset == 1)
+ if err := windows.ResetEvent(eventHandle); err != nil {
+ log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
+ }
+
+ go generateDebugBundle(config, recorder, connectClient, logFilePath)
+ case uint32(windows.WAIT_TIMEOUT):
+
+ default:
+ log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
+ select {
+ case <-time.After(5 * time.Second):
+ case <-ctx.Done():
+ return
+ }
+ }
+ }
+}
diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go
index 761c86628..0ddf6c4c8 100644
--- a/client/cmd/service_controller.go
+++ b/client/cmd/service_controller.go
@@ -115,6 +115,7 @@ var runCmd = &cobra.Command{
ctx, cancel := context.WithCancel(cmd.Context())
SetupCloseHandler(ctx, cancel)
+ SetupDebugHandler(ctx, nil, nil, nil, logFile)
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil {
diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go
index 31bff26cb..70abe4abe 100644
--- a/client/cmd/testutil_test.go
+++ b/client/cmd/testutil_test.go
@@ -92,11 +92,11 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
- permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
+ permissionsManagerMock := permissions.NewMockManager(ctrl)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
diff --git a/client/cmd/up.go b/client/cmd/up.go
index 8b716a96d..bfe41628e 100644
--- a/client/cmd/up.go
+++ b/client/cmd/up.go
@@ -219,6 +219,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
+ SetupDebugHandler(ctx, config, r, connectClient, "")
+
return connectClient.Run(nil)
}
diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go
index 5a471bf24..0e58499aa 100644
--- a/client/iface/bind/udp_mux.go
+++ b/client/iface/bind/udp_mux.go
@@ -458,6 +458,6 @@ func newBufferHolder(size int) *bufferHolder {
func getLogger() logging.LeveledLogger {
fac := logging.NewDefaultLoggerFactory()
- fac.Writer = log.StandardLogger().Writer()
+ //fac.Writer = log.StandardLogger().Writer()
return fac.NewLogger("ice")
}
diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go
index 6c2323412..c5bd84cd5 100644
--- a/client/internal/auth/pkce_flow.go
+++ b/client/internal/auth/pkce_flow.go
@@ -94,13 +94,17 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
p.codeVerifier = codeVerifier
codeChallenge := createCodeChallenge(codeVerifier)
- authURL := p.oAuthConfig.AuthCodeURL(
- state,
+
+ params := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
- oauth2.SetAuthURLParam("prompt", "login"),
- )
+ }
+ if !p.providerConfig.DisablePromptLogin {
+ params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
+ }
+
+ authURL := p.oAuthConfig.AuthCodeURL(state, params...)
return AuthFlowInfo{
VerificationURIComplete: authURL,
diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go
new file mode 100644
index 000000000..4510ed338
--- /dev/null
+++ b/client/internal/auth/pkce_flow_test.go
@@ -0,0 +1,49 @@
+package auth
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/netbirdio/netbird/client/internal"
+)
+
+func TestPromptLogin(t *testing.T) {
+ tt := []struct {
+ name string
+ prompt bool
+ }{
+ {"PromptLogin", true},
+ {"NoPromptLogin", false},
+ }
+
+ for _, tc := range tt {
+ t.Run(tc.name, func(t *testing.T) {
+ config := internal.PKCEAuthProviderConfig{
+ ClientID: "test-client-id",
+ Audience: "test-audience",
+ TokenEndpoint: "https://test-token-endpoint.com/token",
+ Scope: "openid email profile",
+ AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
+ RedirectURLs: []string{"http://127.0.0.1:33992/"},
+ UseIDToken: true,
+ DisablePromptLogin: !tc.prompt,
+ }
+ pkce, err := NewPKCEAuthorizationFlow(config)
+ if err != nil {
+ t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
+ }
+ authInfo, err := pkce.RequestAuthInfo(context.Background())
+ if err != nil {
+ t.Fatalf("Failed to request auth info: %v", err)
+ }
+ pattern := "prompt=login"
+ if tc.prompt {
+ require.Contains(t, authInfo.VerificationURIComplete, pattern)
+ } else {
+ require.NotContains(t, authInfo.VerificationURIComplete, pattern)
+ }
+ })
+ }
+}
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 504c88c6f..832d58dcd 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -349,6 +349,25 @@ func (c *ConnectClient) Engine() *Engine {
return e
}
+// GetLatestNetworkMap returns the latest network map from the engine.
+func (c *ConnectClient) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
+ engine := c.Engine()
+ if engine == nil {
+ return nil, errors.New("engine is not initialized")
+ }
+
+ networkMap, err := engine.GetLatestNetworkMap()
+ if err != nil {
+ return nil, fmt.Errorf("get latest network map: %w", err)
+ }
+
+ if networkMap == nil {
+ return nil, errors.New("network map is not available")
+ }
+
+ return networkMap, nil
+}
+
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {
diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go
new file mode 100644
index 000000000..e07f981fe
--- /dev/null
+++ b/client/internal/debug/debug.go
@@ -0,0 +1,1022 @@
+package debug
+
+import (
+ "archive/zip"
+ "bufio"
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "net"
+ "net/netip"
+ "os"
+ "path/filepath"
+ "runtime"
+ "runtime/pprof"
+ "sort"
+ "strings"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "google.golang.org/protobuf/encoding/protojson"
+
+ "github.com/netbirdio/netbird/client/anonymize"
+ "github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/internal/statemanager"
+ mgmProto "github.com/netbirdio/netbird/management/proto"
+)
+
+const readmeContent = `Netbird debug bundle
+This debug bundle contains the following files.
+If the --anonymize flag is set, the files are anonymized to protect sensitive information.
+
+status.txt: Anonymized status information of the NetBird client.
+client.log: Most recent, anonymized client log file of the NetBird client.
+netbird.err: Most recent, anonymized stderr log file of the NetBird client.
+netbird.out: Most recent, anonymized stdout log file of the NetBird client.
+routes.txt: Anonymized system routes, if --system-info flag was provided.
+interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
+iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
+nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
+config.txt: Anonymized configuration information of the NetBird client.
+network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
+state.json: Anonymized client state dump containing netbird states.
+mutex.prof: Mutex profiling information.
+goroutine.prof: Goroutine profiling information.
+block.prof: Block profiling information.
+heap.prof: Heap profiling information (snapshot of memory allocations).
+allocs.prof: Allocations profiling information.
+threadcreate.prof: Thread creation profiling information.
+
+
+Anonymization Process
+The files in this bundle have been anonymized to protect sensitive information. Here's how the anonymization was applied:
+
+IP Addresses
+
+IPv4 addresses are replaced with addresses starting from 198.51.100.0
+IPv6 addresses are replaced with addresses starting from 100::
+
+IP addresses from non public ranges and well known addresses are not anonymized (e.g. 8.8.8.8, 100.64.0.0/10, addresses starting with 192.168., 172.16., 10., etc.).
+Reoccuring IP addresses are replaced with the same anonymized address.
+
+Note: The anonymized IP addresses in the status file do not match those in the log and routes files. However, the anonymized IP addresses are consistent within the status file and across the routes and log files.
+
+Domains
+All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
+Reoccuring domain names are replaced with the same anonymized domain.
+
+Network Map
+The network_map.json file contains the following anonymized information:
+- Peer configurations (addresses, FQDNs, DNS settings)
+- Remote and offline peer information (allowed IPs, FQDNs)
+- Routes (network ranges, associated domains)
+- DNS configuration (nameservers, domains, custom zones)
+- Firewall rules (peer IPs, source/destination ranges)
+
+SSH keys in the network map are replaced with a placeholder value. All IP addresses and domains in the network map follow the same anonymization rules as described above.
+
+State File
+The state.json file contains anonymized internal state information of the NetBird client, including:
+- DNS settings and configuration
+- Firewall rules
+- Exclusion routes
+- Route selection
+- Other internal states that may be present
+
+The state file follows the same anonymization rules as other files:
+- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure
+- Domain names are consistently anonymized
+- Technical identifiers and non-sensitive data remain unchanged
+
+Mutex, Goroutines, Block, and Heap Profiling Files
+The goroutine, block, mutex, and heap profiling files contain process information that might help the NetBird team diagnose performance or memory issues. The information in these files doesn't contain personal data.
+You can check each using the following go command:
+
+go tool pprof -http=:8088 .prof
+
+For example, to view the heap profile:
+go tool pprof -http=:8088 heap.prof
+
+This will open a web browser tab with the profiling information.
+
+Routes
+For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
+
+Network Interfaces
+The interfaces.txt file contains information about network interfaces, including:
+- Interface name
+- Interface index
+- MTU (Maximum Transmission Unit)
+- Flags
+- IP addresses associated with each interface
+
+The IP addresses in the interfaces file are anonymized using the same process as described above. Interface names, indexes, MTUs, and flags are not anonymized.
+
+Configuration
+The config.txt file contains anonymized configuration information of the NetBird client. Sensitive information such as private keys and SSH keys are excluded. The following fields are anonymized:
+- ManagementURL
+- AdminURL
+- NATExternalIPs
+- CustomDNSAddress
+
+Other non-sensitive configuration options are included without anonymization.
+
+Firewall Rules (Linux only)
+The bundle includes two separate firewall rule files:
+
+iptables.txt:
+- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
+- Includes all tables (filter, nat, mangle, raw, security)
+- Shows packet and byte counters for each rule
+- All IP addresses are anonymized
+- Chain names, table names, and other non-sensitive information remain unchanged
+
+nftables.txt:
+- Complete nftables ruleset obtained via 'nft -a list ruleset'
+- Includes rule handle numbers and packet counters
+- All tables, chains, and rules are included
+- Shows packet and byte counters for each rule
+- All IP addresses are anonymized
+- Chain names, table names, and other non-sensitive information remain unchanged
+`
+
+const (
+ clientLogFile = "client.log"
+ errorLogFile = "netbird.err"
+ stdoutLogFile = "netbird.out"
+
+ darwinErrorLogPath = "/var/log/netbird.out.log"
+ darwinStdoutLogPath = "/var/log/netbird.err.log"
+)
+
+type BundleGenerator struct {
+ anonymizer *anonymize.Anonymizer
+
+ // deps
+ internalConfig *internal.Config
+ statusRecorder *peer.Status
+ networkMap *mgmProto.NetworkMap
+ logFile string
+
+ // config
+ anonymize bool
+ clientStatus string
+ includeSystemInfo bool
+
+ archive *zip.Writer
+}
+
+type BundleConfig struct {
+ Anonymize bool
+ ClientStatus string
+ IncludeSystemInfo bool
+}
+
+type GeneratorDependencies struct {
+ InternalConfig *internal.Config
+ StatusRecorder *peer.Status
+ NetworkMap *mgmProto.NetworkMap
+ LogFile string
+}
+
+func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
+ return &BundleGenerator{
+ anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
+
+ internalConfig: deps.InternalConfig,
+ statusRecorder: deps.StatusRecorder,
+ networkMap: deps.NetworkMap,
+ logFile: deps.LogFile,
+
+ anonymize: cfg.Anonymize,
+ clientStatus: cfg.ClientStatus,
+ includeSystemInfo: cfg.IncludeSystemInfo,
+ }
+}
+
+// Generate creates a debug bundle and returns the location.
+func (g *BundleGenerator) Generate() (resp string, err error) {
+ bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
+ if err != nil {
+ return "", fmt.Errorf("create zip file: %w", err)
+ }
+ defer func() {
+ if closeErr := bundlePath.Close(); closeErr != nil && err == nil {
+ err = fmt.Errorf("close zip file: %w", closeErr)
+ }
+
+ if err != nil {
+ if removeErr := os.Remove(bundlePath.Name()); removeErr != nil {
+ log.Errorf("Failed to remove zip file: %v", removeErr)
+ }
+ }
+ }()
+
+ g.archive = zip.NewWriter(bundlePath)
+
+ if err := g.createArchive(); err != nil {
+ return "", err
+ }
+
+ if err := g.archive.Close(); err != nil {
+ return "", fmt.Errorf("close archive writer: %w", err)
+ }
+
+ return bundlePath.Name(), nil
+}
+
+func (g *BundleGenerator) createArchive() error {
+ if err := g.addReadme(); err != nil {
+ return fmt.Errorf("add readme: %w", err)
+ }
+
+ if err := g.addStatus(); err != nil {
+ return fmt.Errorf("add status: %w", err)
+ }
+
+ if g.statusRecorder != nil {
+ status := g.statusRecorder.GetFullStatus()
+ seedFromStatus(g.anonymizer, &status)
+ } else {
+ log.Debugf("no status recorder available for seeding")
+ }
+
+ if err := g.addConfig(); err != nil {
+ log.Errorf("Failed to add config to debug bundle: %v", err)
+ }
+
+ if g.includeSystemInfo {
+ g.addSystemInfo()
+ }
+
+ if err := g.addProf(); err != nil {
+ log.Errorf("Failed to add profiles to debug bundle: %v", err)
+ }
+
+ if err := g.addNetworkMap(); err != nil {
+ return fmt.Errorf("add network map: %w", err)
+ }
+
+ if err := g.addStateFile(); err != nil {
+ log.Errorf("Failed to add state file to debug bundle: %v", err)
+ }
+
+ if err := g.addCorruptedStateFiles(); err != nil {
+ log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
+ }
+
+ if g.logFile != "console" {
+ if err := g.addLogfile(); err != nil {
+ return fmt.Errorf("add log file: %w", err)
+ }
+ }
+ return nil
+}
+
+func (g *BundleGenerator) addSystemInfo() {
+ if err := g.addRoutes(); err != nil {
+ log.Errorf("Failed to add routes to debug bundle: %v", err)
+ }
+
+ if err := g.addInterfaces(); err != nil {
+ log.Errorf("Failed to add interfaces to debug bundle: %v", err)
+ }
+
+ if err := g.addFirewallRules(); err != nil {
+ log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
+ }
+}
+
+func (g *BundleGenerator) addReadme() error {
+ readmeReader := strings.NewReader(readmeContent)
+ if err := g.addFileToZip(readmeReader, "README.txt"); err != nil {
+ return fmt.Errorf("add README file to zip: %w", err)
+ }
+ return nil
+}
+
+func (g *BundleGenerator) addStatus() error {
+ if status := g.clientStatus; status != "" {
+ statusReader := strings.NewReader(status)
+ if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
+ return fmt.Errorf("add status file to zip: %w", err)
+ }
+ }
+ return nil
+}
+
+func (g *BundleGenerator) addConfig() error {
+ if g.internalConfig == nil {
+ log.Debug("skipping empty config in debug bundle")
+ return nil
+ }
+
+ var configContent strings.Builder
+ g.addCommonConfigFields(&configContent)
+
+ if g.anonymize {
+ if g.internalConfig.ManagementURL != nil {
+ configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String())))
+ }
+ if g.internalConfig.AdminURL != nil {
+ configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String())))
+ }
+ configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", anonymizeNATExternalIPs(g.internalConfig.NATExternalIPs, g.anonymizer)))
+ if g.internalConfig.CustomDNSAddress != "" {
+ configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress)))
+ }
+ } else {
+ if g.internalConfig.ManagementURL != nil {
+ configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", g.internalConfig.ManagementURL.String()))
+ }
+ if g.internalConfig.AdminURL != nil {
+ configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", g.internalConfig.AdminURL.String()))
+ }
+ configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", g.internalConfig.NATExternalIPs))
+ if g.internalConfig.CustomDNSAddress != "" {
+ configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", g.internalConfig.CustomDNSAddress))
+ }
+ }
+
+ // Add config content to zip file
+ configReader := strings.NewReader(configContent.String())
+ if err := g.addFileToZip(configReader, "config.txt"); err != nil {
+ return fmt.Errorf("add config file to zip: %w", err)
+ }
+
+ return nil
+}
+
+func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
+ configContent.WriteString("NetBird Client Configuration:\n\n")
+
+ // Add non-sensitive fields
+ configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
+ configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
+ if g.internalConfig.NetworkMonitor != nil {
+ configContent.WriteString(fmt.Sprintf("NetworkMonitor: %v\n", *g.internalConfig.NetworkMonitor))
+ }
+ configContent.WriteString(fmt.Sprintf("IFaceBlackList: %v\n", g.internalConfig.IFaceBlackList))
+ configContent.WriteString(fmt.Sprintf("DisableIPv6Discovery: %v\n", g.internalConfig.DisableIPv6Discovery))
+ configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled))
+ configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive))
+ if g.internalConfig.ServerSSHAllowed != nil {
+ configContent.WriteString(fmt.Sprintf("BundleGeneratorSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
+ }
+ configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
+ configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
+
+ configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
+ configContent.WriteString(fmt.Sprintf("DisableBundleGeneratorRoutes: %v\n", g.internalConfig.DisableServerRoutes))
+ configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
+ configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
+
+ configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
+}
+
+func (g *BundleGenerator) addProf() (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ err = fmt.Errorf("panic while profiling: %v", r)
+ }
+ }()
+
+ runtime.SetBlockProfileRate(1)
+ _ = runtime.SetMutexProfileFraction(1)
+ defer runtime.SetBlockProfileRate(0)
+ defer runtime.SetMutexProfileFraction(0)
+
+ time.Sleep(5 * time.Second)
+
+ for _, profile := range []string{"goroutine", "block", "mutex", "heap", "allocs", "threadcreate"} {
+ var buff []byte
+ myBuff := bytes.NewBuffer(buff)
+ err := pprof.Lookup(profile).WriteTo(myBuff, 0)
+ if err != nil {
+ return fmt.Errorf("write %s profile: %w", profile, err)
+ }
+
+ if err := g.addFileToZip(myBuff, profile+".prof"); err != nil {
+ return fmt.Errorf("add %s file to zip: %w", profile, err)
+ }
+ }
+ return nil
+}
+
+func (g *BundleGenerator) addInterfaces() error {
+ interfaces, err := net.Interfaces()
+ if err != nil {
+ return fmt.Errorf("get interfaces: %w", err)
+ }
+
+ interfacesContent := formatInterfaces(interfaces, g.anonymize, g.anonymizer)
+ interfacesReader := strings.NewReader(interfacesContent)
+ if err := g.addFileToZip(interfacesReader, "interfaces.txt"); err != nil {
+ return fmt.Errorf("add interfaces file to zip: %w", err)
+ }
+
+ return nil
+}
+
+func (g *BundleGenerator) addNetworkMap() error {
+ if g.networkMap == nil {
+ log.Debugf("skipping empty network map in debug bundle")
+ return nil
+ }
+
+ if g.anonymize {
+ if err := anonymizeNetworkMap(g.networkMap, g.anonymizer); err != nil {
+ return fmt.Errorf("anonymize network map: %w", err)
+ }
+ }
+
+ options := protojson.MarshalOptions{
+ EmitUnpopulated: true,
+ UseProtoNames: true,
+ Indent: " ",
+ AllowPartial: true,
+ }
+
+ jsonBytes, err := options.Marshal(g.networkMap)
+ if err != nil {
+ return fmt.Errorf("generate json: %w", err)
+ }
+
+ if err := g.addFileToZip(bytes.NewReader(jsonBytes), "network_map.json"); err != nil {
+ return fmt.Errorf("add network map to zip: %w", err)
+ }
+
+ return nil
+}
+
+func (g *BundleGenerator) addStateFile() error {
+ path := statemanager.GetDefaultStatePath()
+ if path == "" {
+ return nil
+ }
+
+ data, err := os.ReadFile(path)
+ if err != nil {
+ if errors.Is(err, fs.ErrNotExist) {
+ return nil
+ }
+ return fmt.Errorf("read state file: %w", err)
+ }
+
+ if g.anonymize {
+ var rawStates map[string]json.RawMessage
+ if err := json.Unmarshal(data, &rawStates); err != nil {
+ return fmt.Errorf("unmarshal states: %w", err)
+ }
+
+ if err := anonymizeStateFile(&rawStates, g.anonymizer); err != nil {
+ return fmt.Errorf("anonymize state file: %w", err)
+ }
+
+ bs, err := json.MarshalIndent(rawStates, "", " ")
+ if err != nil {
+ return fmt.Errorf("marshal states: %w", err)
+ }
+ data = bs
+ }
+
+ if err := g.addFileToZip(bytes.NewReader(data), "state.json"); err != nil {
+ return fmt.Errorf("add state file to zip: %w", err)
+ }
+
+ return nil
+}
+
+func (g *BundleGenerator) addCorruptedStateFiles() error {
+ pattern := statemanager.GetDefaultStatePath()
+ if pattern == "" {
+ return nil
+ }
+ pattern += "*.corrupted.*"
+ matches, err := filepath.Glob(pattern)
+ if err != nil {
+ return fmt.Errorf("find corrupted state files: %w", err)
+ }
+
+ for _, match := range matches {
+ data, err := os.ReadFile(match)
+ if err != nil {
+ log.Warnf("Failed to read corrupted state file %s: %v", match, err)
+ continue
+ }
+
+ fileName := filepath.Base(match)
+ if err := g.addFileToZip(bytes.NewReader(data), "corrupted_states/"+fileName); err != nil {
+ log.Warnf("Failed to add corrupted state file %s to zip: %v", fileName, err)
+ continue
+ }
+
+ log.Debugf("Added corrupted state file to debug bundle: %s", fileName)
+ }
+
+ return nil
+}
+
+func (g *BundleGenerator) addLogfile() error {
+ if g.logFile == "" {
+ log.Debugf("skipping empty log file in debug bundle")
+ return nil
+ }
+
+ logDir := filepath.Dir(g.logFile)
+
+ if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
+ return fmt.Errorf("add client log file to zip: %w", err)
+ }
+
+ stdErrLogPath := filepath.Join(logDir, errorLogFile)
+ stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
+ if runtime.GOOS == "darwin" {
+ stdErrLogPath = darwinErrorLogPath
+ stdoutLogPath = darwinStdoutLogPath
+ }
+
+ if err := g.addSingleLogfile(stdErrLogPath, errorLogFile); err != nil {
+ log.Warnf("Failed to add %s to zip: %v", errorLogFile, err)
+ }
+
+ if err := g.addSingleLogfile(stdoutLogPath, stdoutLogFile); err != nil {
+ log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err)
+ }
+
+ return nil
+}
+
+// addSingleLogfile adds a single log file to the archive
+func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
+ logFile, err := os.Open(logPath)
+ if err != nil {
+ return fmt.Errorf("open log file %s: %w", targetName, err)
+ }
+ defer func() {
+ if err := logFile.Close(); err != nil {
+ log.Errorf("Failed to close log file %s: %v", targetName, err)
+ }
+ }()
+
+ var logReader io.Reader
+ if g.anonymize {
+ var writer *io.PipeWriter
+ logReader, writer = io.Pipe()
+
+ go anonymizeLog(logFile, writer, g.anonymizer)
+ } else {
+ logReader = logFile
+ }
+
+ if err := g.addFileToZip(logReader, targetName); err != nil {
+ return fmt.Errorf("add %s to zip: %w", targetName, err)
+ }
+
+ return nil
+}
+
+func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
+ header := &zip.FileHeader{
+ Name: filename,
+ Method: zip.Deflate,
+ Modified: time.Now(),
+
+ CreatorVersion: 20, // Version 2.0
+ ReaderVersion: 20, // Version 2.0
+ Flags: 0x800, // UTF-8 filename
+ }
+
+ // If the reader is a file, we can get more accurate information
+ if f, ok := reader.(*os.File); ok {
+ if stat, err := f.Stat(); err != nil {
+ log.Tracef("Failed to get file stat for %s: %v", filename, err)
+ } else {
+ header.Modified = stat.ModTime()
+ }
+ }
+
+ writer, err := g.archive.CreateHeader(header)
+ if err != nil {
+ return fmt.Errorf("create zip file header: %w", err)
+ }
+
+ if _, err := io.Copy(writer, reader); err != nil {
+ return fmt.Errorf("write file to zip: %w", err)
+ }
+
+ return nil
+}
+
+func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
+ status.ManagementState.URL = a.AnonymizeURI(status.ManagementState.URL)
+ status.SignalState.URL = a.AnonymizeURI(status.SignalState.URL)
+
+ status.LocalPeerState.FQDN = a.AnonymizeDomain(status.LocalPeerState.FQDN)
+
+ for _, p := range status.Peers {
+ a.AnonymizeDomain(p.FQDN)
+ for route := range p.GetRoutes() {
+ a.AnonymizeRoute(route)
+ }
+ }
+
+ for route := range status.LocalPeerState.Routes {
+ a.AnonymizeRoute(route)
+ }
+
+ for _, nsGroup := range status.NSGroupStates {
+ for _, domain := range nsGroup.Domains {
+ a.AnonymizeDomain(domain)
+ }
+ }
+
+ for _, relay := range status.Relays {
+ if relay.URI != "" {
+ a.AnonymizeURI(relay.URI)
+ }
+ }
+}
+
+func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string {
+ var ipv4Routes, ipv6Routes []netip.Prefix
+
+ // Separate IPv4 and IPv6 routes
+ for _, route := range routes {
+ if route.Addr().Is4() {
+ ipv4Routes = append(ipv4Routes, route)
+ } else {
+ ipv6Routes = append(ipv6Routes, route)
+ }
+ }
+
+ // Sort IPv4 and IPv6 routes separately
+ sort.Slice(ipv4Routes, func(i, j int) bool {
+ return ipv4Routes[i].Bits() > ipv4Routes[j].Bits()
+ })
+ sort.Slice(ipv6Routes, func(i, j int) bool {
+ return ipv6Routes[i].Bits() > ipv6Routes[j].Bits()
+ })
+
+ var builder strings.Builder
+
+ // Format IPv4 routes
+ builder.WriteString("IPv4 Routes:\n")
+ for _, route := range ipv4Routes {
+ formatRoute(&builder, route, anonymize, anonymizer)
+ }
+
+ // Format IPv6 routes
+ builder.WriteString("\nIPv6 Routes:\n")
+ for _, route := range ipv6Routes {
+ formatRoute(&builder, route, anonymize, anonymizer)
+ }
+
+ return builder.String()
+}
+
+func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) {
+ if anonymize {
+ anonymizedIP := anonymizer.AnonymizeIP(route.Addr())
+ builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits()))
+ } else {
+ builder.WriteString(fmt.Sprintf("%s\n", route))
+ }
+}
+
+func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string {
+ sort.Slice(interfaces, func(i, j int) bool {
+ return interfaces[i].Name < interfaces[j].Name
+ })
+
+ var builder strings.Builder
+ builder.WriteString("Network Interfaces:\n")
+
+ for _, iface := range interfaces {
+ builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name))
+ builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index))
+ builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU))
+ builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags))
+
+ addrs, err := iface.Addrs()
+ if err != nil {
+ builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err))
+ } else {
+ builder.WriteString(" Addresses:\n")
+ for _, addr := range addrs {
+ prefix, err := netip.ParsePrefix(addr.String())
+ if err != nil {
+ builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err))
+ continue
+ }
+ ip := prefix.Addr()
+ if anonymize {
+ ip = anonymizer.AnonymizeIP(ip)
+ }
+ builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits()))
+ }
+ }
+ }
+
+ return builder.String()
+}
+
+func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
+ defer func() {
+ // always nil
+ _ = writer.Close()
+ }()
+
+ scanner := bufio.NewScanner(reader)
+ for scanner.Scan() {
+ line := anonymizer.AnonymizeString(scanner.Text())
+ if _, err := writer.Write([]byte(line + "\n")); err != nil {
+ if err := writer.CloseWithError(fmt.Errorf("anonymize write: %w", err)); err != nil {
+ log.Errorf("Failed to close writer: %v", err)
+ }
+ return
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ if err := writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)); err != nil {
+ log.Errorf("Failed to close writer: %v", err)
+ }
+ return
+ }
+}
+
+func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
+ anonymizedIPs := make([]string, len(ips))
+ for i, ip := range ips {
+ parts := strings.SplitN(ip, "/", 2)
+
+ ip1, err := netip.ParseAddr(parts[0])
+ if err != nil {
+ anonymizedIPs[i] = ip
+ continue
+ }
+ ip1anon := anonymizer.AnonymizeIP(ip1)
+
+ if len(parts) == 2 {
+ ip2, err := netip.ParseAddr(parts[1])
+ if err != nil {
+ anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, parts[1])
+ } else {
+ ip2anon := anonymizer.AnonymizeIP(ip2)
+ anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, ip2anon)
+ }
+ } else {
+ anonymizedIPs[i] = ip1anon.String()
+ }
+ }
+ return anonymizedIPs
+}
+
+func anonymizeNetworkMap(networkMap *mgmProto.NetworkMap, anonymizer *anonymize.Anonymizer) error {
+ if networkMap.PeerConfig != nil {
+ anonymizePeerConfig(networkMap.PeerConfig, anonymizer)
+ }
+
+ for _, p := range networkMap.RemotePeers {
+ anonymizeRemotePeer(p, anonymizer)
+ }
+
+ for _, p := range networkMap.OfflinePeers {
+ anonymizeRemotePeer(p, anonymizer)
+ }
+
+ for _, r := range networkMap.Routes {
+ anonymizeRoute(r, anonymizer)
+ }
+
+ if networkMap.DNSConfig != nil {
+ anonymizeDNSConfig(networkMap.DNSConfig, anonymizer)
+ }
+
+ for _, rule := range networkMap.FirewallRules {
+ anonymizeFirewallRule(rule, anonymizer)
+ }
+
+ for _, rule := range networkMap.RoutesFirewallRules {
+ anonymizeRouteFirewallRule(rule, anonymizer)
+ }
+
+ return nil
+}
+
+func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anonymizer) {
+ if config == nil {
+ return
+ }
+
+ if addr, err := netip.ParseAddr(config.Address); err == nil {
+ config.Address = anonymizer.AnonymizeIP(addr).String()
+ }
+
+ if config.SshConfig != nil && len(config.SshConfig.SshPubKey) > 0 {
+ config.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
+ }
+
+ config.Dns = anonymizer.AnonymizeString(config.Dns)
+ config.Fqdn = anonymizer.AnonymizeDomain(config.Fqdn)
+}
+
+func anonymizeRemotePeer(peer *mgmProto.RemotePeerConfig, anonymizer *anonymize.Anonymizer) {
+ if peer == nil {
+ return
+ }
+
+ for i, ip := range peer.AllowedIps {
+ // Try to parse as prefix first (CIDR)
+ if prefix, err := netip.ParsePrefix(ip); err == nil {
+ anonIP := anonymizer.AnonymizeIP(prefix.Addr())
+ peer.AllowedIps[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
+ } else if addr, err := netip.ParseAddr(ip); err == nil {
+ peer.AllowedIps[i] = anonymizer.AnonymizeIP(addr).String()
+ }
+ }
+
+ peer.Fqdn = anonymizer.AnonymizeDomain(peer.Fqdn)
+
+ if peer.SshConfig != nil && len(peer.SshConfig.SshPubKey) > 0 {
+ peer.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
+ }
+}
+
+func anonymizeRoute(route *mgmProto.Route, anonymizer *anonymize.Anonymizer) {
+ if route == nil {
+ return
+ }
+
+ if prefix, err := netip.ParsePrefix(route.Network); err == nil {
+ anonIP := anonymizer.AnonymizeIP(prefix.Addr())
+ route.Network = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
+ }
+
+ for i, domain := range route.Domains {
+ route.Domains[i] = anonymizer.AnonymizeDomain(domain)
+ }
+
+ route.NetID = anonymizer.AnonymizeString(route.NetID)
+}
+
+func anonymizeDNSConfig(config *mgmProto.DNSConfig, anonymizer *anonymize.Anonymizer) {
+ if config == nil {
+ return
+ }
+
+ anonymizeNameBundleGeneratorGroups(config.NameServerGroups, anonymizer)
+ anonymizeCustomZones(config.CustomZones, anonymizer)
+}
+
+func anonymizeNameBundleGeneratorGroups(groups []*mgmProto.NameServerGroup, anonymizer *anonymize.Anonymizer) {
+ for _, group := range groups {
+ anonymizeBundleGenerators(group.NameServers, anonymizer)
+ anonymizeDomains(group.Domains, anonymizer)
+ }
+}
+
+func anonymizeBundleGenerators(servers []*mgmProto.NameServer, anonymizer *anonymize.Anonymizer) {
+ for _, server := range servers {
+ if addr, err := netip.ParseAddr(server.IP); err == nil {
+ server.IP = anonymizer.AnonymizeIP(addr).String()
+ }
+ }
+}
+
+func anonymizeDomains(domains []string, anonymizer *anonymize.Anonymizer) {
+ for i, domain := range domains {
+ domains[i] = anonymizer.AnonymizeDomain(domain)
+ }
+}
+
+func anonymizeCustomZones(zones []*mgmProto.CustomZone, anonymizer *anonymize.Anonymizer) {
+ for _, zone := range zones {
+ zone.Domain = anonymizer.AnonymizeDomain(zone.Domain)
+ anonymizeRecords(zone.Records, anonymizer)
+ }
+}
+
+func anonymizeRecords(records []*mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
+ for _, record := range records {
+ record.Name = anonymizer.AnonymizeDomain(record.Name)
+ anonymizeRData(record, anonymizer)
+ }
+}
+
+func anonymizeRData(record *mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
+ switch record.Type {
+ case 1, 28: // A or AAAA record
+ if addr, err := netip.ParseAddr(record.RData); err == nil {
+ record.RData = anonymizer.AnonymizeIP(addr).String()
+ }
+ default:
+ record.RData = anonymizer.AnonymizeString(record.RData)
+ }
+}
+
+func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.Anonymizer) {
+ if rule == nil {
+ return
+ }
+
+ if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
+ rule.PeerIP = anonymizer.AnonymizeIP(addr).String()
+ }
+}
+
+func anonymizeRouteFirewallRule(rule *mgmProto.RouteFirewallRule, anonymizer *anonymize.Anonymizer) {
+ if rule == nil {
+ return
+ }
+
+ for i, sourceRange := range rule.SourceRanges {
+ if prefix, err := netip.ParsePrefix(sourceRange); err == nil {
+ anonIP := anonymizer.AnonymizeIP(prefix.Addr())
+ rule.SourceRanges[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
+ }
+ }
+
+ if prefix, err := netip.ParsePrefix(rule.Destination); err == nil {
+ anonIP := anonymizer.AnonymizeIP(prefix.Addr())
+ rule.Destination = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
+ }
+}
+
+func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error {
+ for name, rawState := range *rawStates {
+ if string(rawState) == "null" {
+ continue
+ }
+
+ var state map[string]any
+ if err := json.Unmarshal(rawState, &state); err != nil {
+ return fmt.Errorf("unmarshal state %s: %w", name, err)
+ }
+
+ state = anonymizeValue(state, anonymizer).(map[string]any)
+
+ bs, err := json.Marshal(state)
+ if err != nil {
+ return fmt.Errorf("marshal state %s: %w", name, err)
+ }
+
+ (*rawStates)[name] = bs
+ }
+
+ return nil
+}
+
+func anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any {
+ switch v := value.(type) {
+ case string:
+ return anonymizeString(v, anonymizer)
+ case map[string]any:
+ return anonymizeMap(v, anonymizer)
+ case []any:
+ return anonymizeSlice(v, anonymizer)
+ }
+ return value
+}
+
+func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string {
+ if prefix, err := netip.ParsePrefix(v); err == nil {
+ anonIP := anonymizer.AnonymizeIP(prefix.Addr())
+ return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
+ }
+ if ip, err := netip.ParseAddr(v); err == nil {
+ return anonymizer.AnonymizeIP(ip).String()
+ }
+ return anonymizer.AnonymizeString(v)
+}
+
+func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any {
+ result := make(map[string]any, len(v))
+ for key, val := range v {
+ newKey := anonymizeMapKey(key, anonymizer)
+ result[newKey] = anonymizeValue(val, anonymizer)
+ }
+ return result
+}
+
+func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string {
+ if prefix, err := netip.ParsePrefix(key); err == nil {
+ anonIP := anonymizer.AnonymizeIP(prefix.Addr())
+ return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
+ }
+ if ip, err := netip.ParseAddr(key); err == nil {
+ return anonymizer.AnonymizeIP(ip).String()
+ }
+ return key
+}
+
+func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any {
+ for i, val := range v {
+ v[i] = anonymizeValue(val, anonymizer)
+ }
+ return v
+}
diff --git a/client/server/debug_linux.go b/client/internal/debug/debug_linux.go
similarity index 95%
rename from client/server/debug_linux.go
rename to client/internal/debug/debug_linux.go
index 60bc40561..291531fea 100644
--- a/client/server/debug_linux.go
+++ b/client/internal/debug/debug_linux.go
@@ -1,9 +1,8 @@
//go:build linux && !android
-package server
+package debug
import (
- "archive/zip"
"bytes"
"encoding/binary"
"fmt"
@@ -14,36 +13,31 @@ import (
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
-
- "github.com/netbirdio/netbird/client/anonymize"
- "github.com/netbirdio/netbird/client/proto"
)
// addFirewallRules collects and adds firewall rules to the archive
-func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
+func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules")
- // Collect and add iptables rules
iptablesRules, err := collectIPTablesRules()
if err != nil {
log.Warnf("Failed to collect iptables rules: %v", err)
} else {
- if req.GetAnonymize() {
- iptablesRules = anonymizer.AnonymizeString(iptablesRules)
+ if g.anonymize {
+ iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
}
- if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
+ if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
log.Warnf("Failed to add iptables rules to bundle: %v", err)
}
}
- // Collect and add nftables rules
nftablesRules, err := collectNFTablesRules()
if err != nil {
log.Warnf("Failed to collect nftables rules: %v", err)
} else {
- if req.GetAnonymize() {
- nftablesRules = anonymizer.AnonymizeString(nftablesRules)
+ if g.anonymize {
+ nftablesRules = g.anonymizer.AnonymizeString(nftablesRules)
}
- if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
+ if err := g.addFileToZip(strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
log.Warnf("Failed to add nftables rules to bundle: %v", err)
}
}
@@ -65,16 +59,13 @@ func collectIPTablesRules() (string, error) {
builder.WriteString("\n")
}
- // Then get verbose statistics for each table
builder.WriteString("=== iptables -v -n -L output ===\n")
- // Get list of tables
tables := []string{"filter", "nat", "mangle", "raw", "security"}
for _, table := range tables {
builder.WriteString(fmt.Sprintf("*%s\n", table))
- // Get verbose statistics for the entire table
stats, err := getTableStatistics(table)
if err != nil {
log.Warnf("Failed to get statistics for table %s: %v", table, err)
@@ -182,12 +173,10 @@ func formatTables(conn *nftables.Conn, tables []*nftables.Table) string {
continue
}
- // Format chains
for _, chain := range chains {
formatChain(conn, table, chain, &builder)
}
- // Format sets
if sets, err := conn.GetSets(table); err != nil {
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
} else if len(sets) > 0 {
diff --git a/client/internal/debug/debug_mobile.go b/client/internal/debug/debug_mobile.go
new file mode 100644
index 000000000..c00c65132
--- /dev/null
+++ b/client/internal/debug/debug_mobile.go
@@ -0,0 +1,7 @@
+//go:build ios || android
+
+package debug
+
+func (g *BundleGenerator) addRoutes() error {
+ return nil
+}
diff --git a/client/internal/debug/debug_nonlinux.go b/client/internal/debug/debug_nonlinux.go
new file mode 100644
index 000000000..ef93620a0
--- /dev/null
+++ b/client/internal/debug/debug_nonlinux.go
@@ -0,0 +1,8 @@
+//go:build !linux || android
+
+package debug
+
+// collectFirewallRules returns nothing on non-linux systems
+func (g *BundleGenerator) addFirewallRules() error {
+ return nil
+}
diff --git a/client/internal/debug/debug_nonmobile.go b/client/internal/debug/debug_nonmobile.go
new file mode 100644
index 000000000..3b487f07f
--- /dev/null
+++ b/client/internal/debug/debug_nonmobile.go
@@ -0,0 +1,25 @@
+//go:build !ios && !android
+
+package debug
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
+)
+
+func (g *BundleGenerator) addRoutes() error {
+ routes, err := systemops.GetRoutesFromTable()
+ if err != nil {
+ return fmt.Errorf("get routes: %w", err)
+ }
+
+ // TODO: get routes including nexthop
+ routesContent := formatRoutes(routes, g.anonymize, g.anonymizer)
+ routesReader := strings.NewReader(routesContent)
+ if err := g.addFileToZip(routesReader, "routes.txt"); err != nil {
+ return fmt.Errorf("add routes file to zip: %w", err)
+ }
+ return nil
+}
diff --git a/client/server/debug_test.go b/client/internal/debug/debug_test.go
similarity index 99%
rename from client/server/debug_test.go
rename to client/internal/debug/debug_test.go
index ebd0bffbc..eb91fed66 100644
--- a/client/server/debug_test.go
+++ b/client/internal/debug/debug_test.go
@@ -1,4 +1,4 @@
-package server
+package debug
import (
"encoding/json"
diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go
index 3a25a23b6..76e18e3ce 100644
--- a/client/internal/dns/local.go
+++ b/client/internal/dns/local.go
@@ -71,6 +71,12 @@ func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
value, found := d.records.Load(key)
if !found {
+ // alternatively check if we have a cname
+ if question.Qtype != dns.TypeCNAME {
+ r.Question[0].Qtype = dns.TypeCNAME
+ return d.lookupRecords(r)
+ }
+
return nil
}
diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go
index 74ab6717f..65b90e5f0 100644
--- a/client/internal/dns/server.go
+++ b/client/internal/dns/server.go
@@ -467,6 +467,11 @@ func (s *DefaultServer) applyHostConfig() {
return
}
+ // prevent reapplying config if we're shutting down
+ if s.ctx.Err() != nil {
+ return
+ }
+
config := s.currentConfig
existingDomains := make(map[string]struct{})
diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go
index 097daa9e2..2d69ce858 100644
--- a/client/internal/dnsfwd/forwarder.go
+++ b/client/internal/dnsfwd/forwarder.go
@@ -3,6 +3,7 @@ package dnsfwd
import (
"context"
"errors"
+ "math"
"net"
"net/netip"
"strings"
@@ -62,7 +63,6 @@ func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string)
for _, d := range f.domains {
f.mux.HandleRemove(d)
- f.statusRecorder.RemoveResolvedIPLookupEntry(d)
}
f.resId.Clear()
@@ -122,8 +122,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
return
}
- resId, ok := f.resId.Load(strings.TrimSuffix(domain, "."))
- if ok {
+ resId := f.getResIdForDomain(strings.TrimSuffix(domain, "."))
+ if resId != "" {
for _, ip := range ips {
var ipWithSuffix string
if ip.Is4() {
@@ -133,7 +133,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
ipWithSuffix = ip.String() + "/128"
log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix)
}
- f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId.(string))
+ f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId)
}
}
@@ -204,6 +204,36 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti
}
}
+func (f *DNSForwarder) getResIdForDomain(domain string) string {
+ var selectedResId string
+ var bestScore int
+
+ f.resId.Range(func(key, value interface{}) bool {
+ var score int
+ pattern := key.(string)
+
+ switch {
+ case strings.HasPrefix(pattern, "*."):
+ baseDomain := strings.TrimPrefix(pattern, "*.")
+ if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
+ score = len(baseDomain)
+ }
+ case domain == pattern:
+ score = math.MaxInt
+ default:
+ return true
+ }
+
+ if score > bestScore {
+ bestScore = score
+ selectedResId = value.(string)
+ }
+ return true
+ })
+
+ return selectedResId
+}
+
// filterDomains returns a list of normalized domains
func filterDomains(domains []string) []string {
newDomains := make([]string, 0, len(domains))
diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go
new file mode 100644
index 000000000..88ffc2af3
--- /dev/null
+++ b/client/internal/dnsfwd/forwarder_test.go
@@ -0,0 +1,95 @@
+package dnsfwd
+
+import (
+ "sync"
+ "testing"
+)
+
+func TestGetResIdForDomain(t *testing.T) {
+ testCases := []struct {
+ name string
+ storedMappings map[string]string // key: domain pattern, value: resId
+ queryDomain string
+ expectedResId string
+ }{
+ {
+ name: "Empty map returns empty string",
+ storedMappings: map[string]string{},
+ queryDomain: "example.com",
+ expectedResId: "",
+ },
+ {
+ name: "Exact match returns stored resId",
+ storedMappings: map[string]string{"example.com": "res1"},
+ queryDomain: "example.com",
+ expectedResId: "res1",
+ },
+ {
+ name: "Wildcard pattern matches base domain",
+ storedMappings: map[string]string{"*.example.com": "res2"},
+ queryDomain: "example.com",
+ expectedResId: "res2",
+ },
+ {
+ name: "Wildcard pattern matches subdomain",
+ storedMappings: map[string]string{"*.example.com": "res3"},
+ queryDomain: "foo.example.com",
+ expectedResId: "res3",
+ },
+ {
+ name: "Wildcard pattern does not match different domain",
+ storedMappings: map[string]string{"*.example.com": "res4"},
+ queryDomain: "foo.notexample.com",
+ expectedResId: "",
+ },
+ {
+ name: "Non-wildcard pattern does not match subdomain",
+ storedMappings: map[string]string{"example.com": "res5"},
+ queryDomain: "foo.example.com",
+ expectedResId: "",
+ },
+ {
+ name: "Exact match over overlapping wildcard",
+ storedMappings: map[string]string{
+ "*.example.com": "resWildcard",
+ "foo.example.com": "resExact",
+ },
+ queryDomain: "foo.example.com",
+ expectedResId: "resExact",
+ },
+ {
+ name: "Overlapping wildcards: Select more specific wildcard",
+ storedMappings: map[string]string{
+ "*.example.com": "resA",
+ "*.sub.example.com": "resB",
+ },
+ queryDomain: "bar.sub.example.com",
+ expectedResId: "resB",
+ },
+ {
+ name: "Wildcard multi-level subdomain match",
+ storedMappings: map[string]string{
+ "*.example.com": "resMulti",
+ },
+ queryDomain: "a.b.example.com",
+ expectedResId: "resMulti",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ fwd := &DNSForwarder{
+ resId: sync.Map{},
+ }
+
+ for domainPattern, resId := range tc.storedMappings {
+ fwd.resId.Store(domainPattern, resId)
+ }
+
+ got := fwd.getResIdForDomain(tc.queryDomain)
+ if got != tc.expectedResId {
+ t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got)
+ }
+ })
+ }
+}
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 74a07927c..c377c12e1 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -1231,36 +1231,19 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
PreSharedKey: e.config.PreSharedKey,
}
- if e.config.RosenpassEnabled && !e.config.RosenpassPermissive {
- lk := []byte(e.config.WgPrivateKey.PublicKey().String())
- rk := []byte(wgConfig.RemoteKey)
- var keyInput []byte
- if string(lk) > string(rk) {
- //nolint:gocritic
- keyInput = append(lk[:16], rk[:16]...)
- } else {
- //nolint:gocritic
- keyInput = append(rk[:16], lk[:16]...)
- }
-
- key, err := wgtypes.NewKey(keyInput)
- if err != nil {
- return nil, err
- }
-
- wgConfig.PreSharedKey = &key
- }
-
// randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{
- Key: pubKey,
- LocalKey: e.config.WgPrivateKey.PublicKey().String(),
- Timeout: timeout,
- WgConfig: wgConfig,
- LocalWgPort: e.config.WgPort,
- RosenpassPubKey: e.getRosenpassPubKey(),
- RosenpassAddr: e.getRosenpassAddr(),
+ Key: pubKey,
+ LocalKey: e.config.WgPrivateKey.PublicKey().String(),
+ Timeout: timeout,
+ WgConfig: wgConfig,
+ LocalWgPort: e.config.WgPort,
+ RosenpassConfig: peer.RosenpassConfig{
+ PubKey: e.getRosenpassPubKey(),
+ Addr: e.getRosenpassAddr(),
+ PermissiveMode: e.config.RosenpassPermissive,
+ },
ICEConfig: icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index 352abd62b..7afe0fcd6 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -1439,8 +1439,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
- permissionsManagerMock := permissions.NewManagerMock()
-
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
@@ -1449,7 +1447,9 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
Return(&types.Settings{}, nil).
AnyTimes()
- accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
+ permissionsManager := permissions.NewManager(store)
+
+ accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
if err != nil {
return nil, "", err
}
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 85f94b53f..44e8997bc 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -60,6 +60,15 @@ type WgConfig struct {
PreSharedKey *wgtypes.Key
}
+type RosenpassConfig struct {
+ // RosenpassPubKey is this peer's Rosenpass public key
+ PubKey []byte
+ // RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
+ Addr string
+
+ PermissiveMode bool
+}
+
// ConnConfig is a peer Connection configuration
type ConnConfig struct {
// Key is a public key of a remote peer
@@ -73,10 +82,7 @@ type ConnConfig struct {
LocalWgPort int
- // RosenpassPubKey is this peer's Rosenpass public key
- RosenpassPubKey []byte
- // RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
- RosenpassAddr string
+ RosenpassConfig RosenpassConfig
// ICEConfig ICE protocol configuration
ICEConfig icemaker.Config
@@ -109,6 +115,8 @@ type Conn struct {
connIDICE nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
+ // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
+ rosenpassRemoteKey []byte
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
@@ -375,7 +383,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
wgProxy.Work()
}
- if err = conn.configureWGEndpoint(ep); err != nil {
+ if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return
}
@@ -408,7 +416,7 @@ func (conn *Conn) onICEStateDisconnected() {
conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work()
- if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
+ if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
@@ -478,7 +486,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
}
wgProxy.Work()
- if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
+ if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err)
}
@@ -493,6 +501,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
}()
wgConfigWorkaround()
+ conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.setRelayedProxy(wgProxy)
@@ -556,13 +565,14 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) {
}
}
-func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
+func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
+ presharedKey := conn.presharedKey(remoteRPKey)
return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey,
conn.config.WgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
- conn.config.WgConfig.PreSharedKey,
+ presharedKey,
)
}
@@ -783,6 +793,44 @@ func (conn *Conn) AllowedIP() netip.Addr {
return conn.config.WgConfig.AllowedIps[0].Addr()
}
+func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
+ if conn.config.RosenpassConfig.PubKey == nil {
+ return conn.config.WgConfig.PreSharedKey
+ }
+
+ if remoteRosenpassKey == nil && conn.config.RosenpassConfig.PermissiveMode {
+ return conn.config.WgConfig.PreSharedKey
+ }
+
+ determKey, err := conn.rosenpassDetermKey()
+ if err != nil {
+ conn.log.Errorf("failed to generate Rosenpass initial key: %v", err)
+ return conn.config.WgConfig.PreSharedKey
+ }
+
+ return determKey
+}
+
+// todo: move this logic into Rosenpass package
+func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
+ lk := []byte(conn.config.LocalKey)
+ rk := []byte(conn.config.Key) // remote key
+ var keyInput []byte
+ if string(lk) > string(rk) {
+ //nolint:gocritic
+ keyInput = append(lk[:16], rk[:16]...)
+ } else {
+ //nolint:gocritic
+ keyInput = append(rk[:16], lk[:16]...)
+ }
+
+ key, err := wgtypes.NewKey(keyInput)
+ if err != nil {
+ return nil, err
+ }
+ return &key, nil
+}
+
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}
diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go
index 505bedb7f..6d55cfff4 100644
--- a/client/internal/peer/conn_test.go
+++ b/client/internal/peer/conn_test.go
@@ -2,6 +2,7 @@ package peer
import (
"context"
+ "fmt"
"os"
"sync"
"testing"
@@ -161,3 +162,145 @@ func TestConn_Status(t *testing.T) {
})
}
}
+
+func TestConn_presharedKey(t *testing.T) {
+ conn1 := Conn{
+ config: ConnConfig{
+ Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
+ LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
+ RosenpassConfig: RosenpassConfig{},
+ },
+ }
+ conn2 := Conn{
+ config: ConnConfig{
+ Key: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
+ LocalKey: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
+ RosenpassConfig: RosenpassConfig{},
+ },
+ }
+
+ tests := []struct {
+ conn1Permissive bool
+ conn1RosenpassEnabled bool
+ conn2Permissive bool
+ conn2RosenpassEnabled bool
+ conn1ExpectedInitialKey bool
+ conn2ExpectedInitialKey bool
+ }{
+ {
+ conn1Permissive: false,
+ conn1RosenpassEnabled: false,
+ conn2Permissive: false,
+ conn2RosenpassEnabled: false,
+ conn1ExpectedInitialKey: false,
+ conn2ExpectedInitialKey: false,
+ },
+ {
+ conn1Permissive: false,
+ conn1RosenpassEnabled: true,
+ conn2Permissive: false,
+ conn2RosenpassEnabled: true,
+ conn1ExpectedInitialKey: true,
+ conn2ExpectedInitialKey: true,
+ },
+ {
+ conn1Permissive: false,
+ conn1RosenpassEnabled: true,
+ conn2Permissive: false,
+ conn2RosenpassEnabled: false,
+ conn1ExpectedInitialKey: true,
+ conn2ExpectedInitialKey: false,
+ },
+ {
+ conn1Permissive: false,
+ conn1RosenpassEnabled: false,
+ conn2Permissive: false,
+ conn2RosenpassEnabled: true,
+ conn1ExpectedInitialKey: false,
+ conn2ExpectedInitialKey: true,
+ },
+ {
+ conn1Permissive: true,
+ conn1RosenpassEnabled: true,
+ conn2Permissive: false,
+ conn2RosenpassEnabled: false,
+ conn1ExpectedInitialKey: false,
+ conn2ExpectedInitialKey: false,
+ },
+ {
+ conn1Permissive: false,
+ conn1RosenpassEnabled: false,
+ conn2Permissive: true,
+ conn2RosenpassEnabled: true,
+ conn1ExpectedInitialKey: false,
+ conn2ExpectedInitialKey: false,
+ },
+ {
+ conn1Permissive: true,
+ conn1RosenpassEnabled: true,
+ conn2Permissive: true,
+ conn2RosenpassEnabled: true,
+ conn1ExpectedInitialKey: true,
+ conn2ExpectedInitialKey: true,
+ },
+ {
+ conn1Permissive: false,
+ conn1RosenpassEnabled: false,
+ conn2Permissive: false,
+ conn2RosenpassEnabled: true,
+ conn1ExpectedInitialKey: false,
+ conn2ExpectedInitialKey: true,
+ },
+ {
+ conn1Permissive: false,
+ conn1RosenpassEnabled: true,
+ conn2Permissive: true,
+ conn2RosenpassEnabled: true,
+ conn1ExpectedInitialKey: true,
+ conn2ExpectedInitialKey: true,
+ },
+ }
+
+ conn1.config.RosenpassConfig.PermissiveMode = true
+ for i, test := range tests {
+ tcase := i + 1
+ t.Run(fmt.Sprintf("Rosenpass test case %d", tcase), func(t *testing.T) {
+ conn1.config.RosenpassConfig = RosenpassConfig{}
+ conn2.config.RosenpassConfig = RosenpassConfig{}
+
+ if test.conn1RosenpassEnabled {
+ conn1.config.RosenpassConfig.PubKey = []byte("dummykey")
+ }
+ conn1.config.RosenpassConfig.PermissiveMode = test.conn1Permissive
+
+ if test.conn2RosenpassEnabled {
+ conn2.config.RosenpassConfig.PubKey = []byte("dummykey")
+ }
+ conn2.config.RosenpassConfig.PermissiveMode = test.conn2Permissive
+
+ conn1PresharedKey := conn1.presharedKey(conn2.config.RosenpassConfig.PubKey)
+ conn2PresharedKey := conn2.presharedKey(conn1.config.RosenpassConfig.PubKey)
+
+ if test.conn1ExpectedInitialKey {
+ if conn1PresharedKey == nil {
+ t.Errorf("Case %d: Expected conn1 to have a non-nil key, but got nil", tcase)
+ }
+ } else {
+ if conn1PresharedKey != nil {
+ t.Errorf("Case %d: Expected conn1 to have a nil key, but got %v", tcase, conn1PresharedKey)
+ }
+ }
+
+ // Assert conn2's key expectation
+ if test.conn2ExpectedInitialKey {
+ if conn2PresharedKey == nil {
+ t.Errorf("Case %d: Expected conn2 to have a non-nil key, but got nil", tcase)
+ }
+ } else {
+ if conn2PresharedKey != nil {
+ t.Errorf("Case %d: Expected conn2 to have a nil key, but got %v", tcase, conn2PresharedKey)
+ }
+ }
+ })
+ }
+}
diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go
index d23727e96..224ea0262 100644
--- a/client/internal/peer/handshaker.go
+++ b/client/internal/peer/handshaker.go
@@ -154,8 +154,8 @@ func (h *Handshaker) sendOffer() error {
IceCredentials: IceCredentials{iceUFrag, icePwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
- RosenpassPubKey: h.config.RosenpassPubKey,
- RosenpassAddr: h.config.RosenpassAddr,
+ RosenpassPubKey: h.config.RosenpassConfig.PubKey,
+ RosenpassAddr: h.config.RosenpassConfig.Addr,
}
addr, err := h.relay.RelayInstanceAddress()
@@ -174,8 +174,8 @@ func (h *Handshaker) sendAnswer() error {
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
- RosenpassPubKey: h.config.RosenpassPubKey,
- RosenpassAddr: h.config.RosenpassAddr,
+ RosenpassPubKey: h.config.RosenpassConfig.PubKey,
+ RosenpassAddr: h.config.RosenpassConfig.Addr,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {
diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go
index 2b66610e9..9b63cebf0 100644
--- a/client/internal/peer/ice/agent.go
+++ b/client/internal/peer/ice/agent.go
@@ -37,7 +37,8 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida
}
fac := logging.NewDefaultLoggerFactory()
- fac.Writer = log.StandardLogger().Writer()
+
+ //fac.Writer = log.StandardLogger().Writer()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go
index ac6734b0c..34eb2df1c 100644
--- a/client/internal/pkce_auth.go
+++ b/client/internal/pkce_auth.go
@@ -39,6 +39,8 @@ type PKCEAuthProviderConfig struct {
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
+ // DisablePromptLogin makes the PKCE flow to not prompt the user for login
+ DisablePromptLogin bool
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
@@ -97,6 +99,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
+ DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
},
}
diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go
index 2874604fd..72c4758f4 100644
--- a/client/internal/routeselector/routeselector.go
+++ b/client/internal/routeselector/routeselector.go
@@ -10,20 +10,27 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/errors"
- route "github.com/netbirdio/netbird/route"
+ "github.com/netbirdio/netbird/route"
)
type RouteSelector struct {
mu sync.RWMutex
selectedRoutes map[route.NetID]struct{}
selectAll bool
+
+ // Indicates if new routes should be automatically selected
+ includeNewRoutes bool
+
+ // All known routes at the time of deselection
+ knownRoutes []route.NetID
}
func NewRouteSelector() *RouteSelector {
return &RouteSelector{
- selectedRoutes: map[route.NetID]struct{}{},
- // default selects all routes
- selectAll: true,
+ selectedRoutes: map[route.NetID]struct{}{},
+ selectAll: true,
+ includeNewRoutes: false,
+ knownRoutes: []route.NetID{},
}
}
@@ -46,6 +53,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
rs.selectedRoutes[route] = struct{}{}
}
rs.selectAll = false
+ rs.includeNewRoutes = false
return errors.FormatErrorOrNil(err)
}
@@ -57,16 +65,22 @@ func (rs *RouteSelector) SelectAllRoutes() {
rs.selectAll = true
rs.selectedRoutes = map[route.NetID]struct{}{}
+ rs.includeNewRoutes = false
}
// DeselectRoutes removes specific routes from the selection.
-// If the selector is in "select all" mode, it will transition to "select specific" mode.
+// If the selector is in "select all" mode, it will transition to "select specific" mode
+// but will keep new routes selected.
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.selectAll {
rs.selectAll = false
+ rs.includeNewRoutes = true
+ rs.knownRoutes = make([]route.NetID, len(allRoutes))
+ copy(rs.knownRoutes, allRoutes)
+
rs.selectedRoutes = map[route.NetID]struct{}{}
for _, route := range allRoutes {
rs.selectedRoutes[route] = struct{}{}
@@ -92,6 +106,7 @@ func (rs *RouteSelector) DeselectAllRoutes() {
defer rs.mu.Unlock()
rs.selectAll = false
+ rs.includeNewRoutes = false
rs.selectedRoutes = map[route.NetID]struct{}{}
}
@@ -103,8 +118,20 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
if rs.selectAll {
return true
}
+
+ // Check if the route exists in selectedRoutes
_, selected := rs.selectedRoutes[routeID]
- return selected
+ if selected {
+ return true
+ }
+
+ // If includeNewRoutes is true and this is a new route (not in knownRoutes),
+ // then it should be selected
+ if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) {
+ return true
+ }
+
+ return false
}
// FilterSelected removes unselected routes from the provided map.
@@ -118,7 +145,11 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
filtered := route.HAMap{}
for id, rt := range routes {
- if rs.IsSelected(id.NetID()) {
+ netID := id.NetID()
+ _, selected := rs.selectedRoutes[netID]
+
+ // Include if directly selected or if it's a new route and includeNewRoutes is true
+ if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) {
filtered[id] = rt
}
}
@@ -131,11 +162,15 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
defer rs.mu.RUnlock()
return json.Marshal(struct {
- SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
- SelectAll bool `json:"select_all"`
+ SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
+ SelectAll bool `json:"select_all"`
+ IncludeNewRoutes bool `json:"include_new_routes"`
+ KnownRoutes []route.NetID `json:"known_routes"`
}{
- SelectAll: rs.selectAll,
- SelectedRoutes: rs.selectedRoutes,
+ SelectAll: rs.selectAll,
+ SelectedRoutes: rs.selectedRoutes,
+ IncludeNewRoutes: rs.includeNewRoutes,
+ KnownRoutes: rs.knownRoutes,
})
}
@@ -149,12 +184,16 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
if len(data) == 0 || string(data) == "null" {
rs.selectedRoutes = map[route.NetID]struct{}{}
rs.selectAll = true
+ rs.includeNewRoutes = false
+ rs.knownRoutes = []route.NetID{}
return nil
}
var temp struct {
- SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
- SelectAll bool `json:"select_all"`
+ SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
+ SelectAll bool `json:"select_all"`
+ IncludeNewRoutes bool `json:"include_new_routes"`
+ KnownRoutes []route.NetID `json:"known_routes"`
}
if err := json.Unmarshal(data, &temp); err != nil {
@@ -163,10 +202,15 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
rs.selectedRoutes = temp.SelectedRoutes
rs.selectAll = temp.SelectAll
+ rs.includeNewRoutes = temp.IncludeNewRoutes
+ rs.knownRoutes = temp.KnownRoutes
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
+ if rs.knownRoutes == nil {
+ rs.knownRoutes = []route.NetID{}
+ }
return nil
}
diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go
index b1671f254..a1461dff6 100644
--- a/client/internal/routeselector/routeselector_test.go
+++ b/client/internal/routeselector/routeselector_test.go
@@ -316,7 +316,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
},
// After deselecting specific routes, new routes should remain unselected
- wantNewSelected: []route.NetID{"route2", "route3"},
+ wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"},
},
{
name: "New routes after selecting with append",
@@ -358,3 +358,73 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
})
}
}
+
+func TestRouteSelector_MixedSelectionDeselection(t *testing.T) {
+ allRoutes := []route.NetID{"route1", "route2", "route3"}
+
+ tests := []struct {
+ name string
+ routesToSelect []route.NetID
+ selectAppend bool
+ routesToDeselect []route.NetID
+ selectFirst bool
+ wantSelectedFinal []route.NetID
+ }{
+ {
+ name: "1. Select A, then Deselect B",
+ routesToSelect: []route.NetID{"route1"},
+ selectAppend: false,
+ routesToDeselect: []route.NetID{"route2"},
+ selectFirst: true,
+ wantSelectedFinal: []route.NetID{"route1"},
+ },
+ {
+ name: "2. Select A, then Deselect A",
+ routesToSelect: []route.NetID{"route1"},
+ selectAppend: false,
+ routesToDeselect: []route.NetID{"route1"},
+ selectFirst: true,
+ wantSelectedFinal: []route.NetID{},
+ },
+ {
+ name: "3. Deselect A (from all), then Select B",
+ routesToSelect: []route.NetID{"route2"},
+ selectAppend: false,
+ routesToDeselect: []route.NetID{"route1"},
+ selectFirst: false,
+ wantSelectedFinal: []route.NetID{"route2"},
+ },
+ {
+ name: "4. Deselect A (from all), then Select A",
+ routesToSelect: []route.NetID{"route1"},
+ selectAppend: false,
+ routesToDeselect: []route.NetID{"route1"},
+ selectFirst: false,
+ wantSelectedFinal: []route.NetID{"route1"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ rs := routeselector.NewRouteSelector()
+
+ var err1, err2 error
+
+ if tt.selectFirst {
+ err1 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes)
+ require.NoError(t, err1)
+ err2 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes)
+ require.NoError(t, err2)
+ } else {
+ err1 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes)
+ require.NoError(t, err1)
+ err2 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes)
+ require.NoError(t, err2)
+ }
+
+ for _, r := range allRoutes {
+ assert.Equal(t, slices.Contains(tt.wantSelectedFinal, r), rs.IsSelected(r), "Route %s final state mismatch", r)
+ }
+ })
+ }
+}
diff --git a/client/server/debug.go b/client/server/debug.go
index bdb1f7543..9ccfb13fb 100644
--- a/client/server/debug.go
+++ b/client/server/debug.go
@@ -3,558 +3,46 @@
package server
import (
- "archive/zip"
- "bufio"
- "bytes"
"context"
- "encoding/json"
"errors"
"fmt"
- "io"
- "io/fs"
- "net"
- "net/netip"
- "os"
- "path/filepath"
- "runtime"
- "runtime/pprof"
- "sort"
- "strings"
- "time"
log "github.com/sirupsen/logrus"
- "google.golang.org/protobuf/encoding/protojson"
- "github.com/netbirdio/netbird/client/anonymize"
- "github.com/netbirdio/netbird/client/internal/peer"
- "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
- "github.com/netbirdio/netbird/client/internal/statemanager"
+ "github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
-const readmeContent = `Netbird debug bundle
-This debug bundle contains the following files:
-
-status.txt: Anonymized status information of the NetBird client.
-client.log: Most recent, anonymized client log file of the NetBird client.
-netbird.err: Most recent, anonymized stderr log file of the NetBird client.
-netbird.out: Most recent, anonymized stdout log file of the NetBird client.
-routes.txt: Anonymized system routes, if --system-info flag was provided.
-interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
-iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
-nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
-config.txt: Anonymized configuration information of the NetBird client.
-network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
-state.json: Anonymized client state dump containing netbird states.
-mutex.prof: Mutex profiling information.
-goroutine.prof: Goroutine profiling information.
-block.prof: Block profiling information.
-
-
-Anonymization Process
-The files in this bundle have been anonymized to protect sensitive information. Here's how the anonymization was applied:
-
-IP Addresses
-
-IPv4 addresses are replaced with addresses starting from 198.51.100.0
-IPv6 addresses are replaced with addresses starting from 100::
-
-IP addresses from non public ranges and well known addresses are not anonymized (e.g. 8.8.8.8, 100.64.0.0/10, addresses starting with 192.168., 172.16., 10., etc.).
-Reoccuring IP addresses are replaced with the same anonymized address.
-
-Note: The anonymized IP addresses in the status file do not match those in the log and routes files. However, the anonymized IP addresses are consistent within the status file and across the routes and log files.
-
-Domains
-All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
-Reoccuring domain names are replaced with the same anonymized domain.
-
-Network Map
-The network_map.json file contains the following anonymized information:
-- Peer configurations (addresses, FQDNs, DNS settings)
-- Remote and offline peer information (allowed IPs, FQDNs)
-- Routes (network ranges, associated domains)
-- DNS configuration (nameservers, domains, custom zones)
-- Firewall rules (peer IPs, source/destination ranges)
-
-SSH keys in the network map are replaced with a placeholder value. All IP addresses and domains in the network map follow the same anonymization rules as described above.
-
-State File
-The state.json file contains anonymized internal state information of the NetBird client, including:
-- DNS settings and configuration
-- Firewall rules
-- Exclusion routes
-- Route selection
-- Other internal states that may be present
-
-The state file follows the same anonymization rules as other files:
-- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure
-- Domain names are consistently anonymized
-- Technical identifiers and non-sensitive data remain unchanged
-
-Mutex, Goroutines, and Block Profiling Files
-The goroutine, block, and mutex profiling files contains process information that might help the NetBird team diagnose performance issues. The information in these files don't contain personal data.
-You can check each using the following go command:
-
-go tool pprof -http=:8088 mutex.prof
-
-This will open a web browser tab with the profiling information.
-
-Routes
-For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
-
-Network Interfaces
-The interfaces.txt file contains information about network interfaces, including:
-- Interface name
-- Interface index
-- MTU (Maximum Transmission Unit)
-- Flags
-- IP addresses associated with each interface
-
-The IP addresses in the interfaces file are anonymized using the same process as described above. Interface names, indexes, MTUs, and flags are not anonymized.
-
-Configuration
-The config.txt file contains anonymized configuration information of the NetBird client. Sensitive information such as private keys and SSH keys are excluded. The following fields are anonymized:
-- ManagementURL
-- AdminURL
-- NATExternalIPs
-- CustomDNSAddress
-
-Other non-sensitive configuration options are included without anonymization.
-
-Firewall Rules (Linux only)
-The bundle includes two separate firewall rule files:
-
-iptables.txt:
-- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
-- Includes all tables (filter, nat, mangle, raw, security)
-- Shows packet and byte counters for each rule
-- All IP addresses are anonymized
-- Chain names, table names, and other non-sensitive information remain unchanged
-
-nftables.txt:
-- Complete nftables ruleset obtained via 'nft -a list ruleset'
-- Includes rule handle numbers and packet counters
-- All tables, chains, and rules are included
-- Shows packet and byte counters for each rule
-- All IP addresses are anonymized
-- Chain names, table names, and other non-sensitive information remain unchanged
-`
-
-const (
- clientLogFile = "client.log"
- errorLogFile = "netbird.err"
- stdoutLogFile = "netbird.out"
-
- darwinErrorLogPath = "/var/log/netbird.out.log"
- darwinStdoutLogPath = "/var/log/netbird.err.log"
-)
-
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
defer s.mutex.Unlock()
- bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
- if err != nil {
- return nil, fmt.Errorf("create zip file: %w", err)
- }
- defer func() {
- if closeErr := bundlePath.Close(); closeErr != nil && err == nil {
- err = fmt.Errorf("close zip file: %w", closeErr)
- }
-
- if err != nil {
- if removeErr := os.Remove(bundlePath.Name()); removeErr != nil {
- log.Errorf("Failed to remove zip file: %v", removeErr)
- }
- }
- }()
-
- if err := s.createArchive(bundlePath, req); err != nil {
- return nil, err
- }
-
- return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
-}
-
-func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleRequest) error {
- archive := zip.NewWriter(bundlePath)
- if err := s.addReadme(req, archive); err != nil {
- return fmt.Errorf("add readme: %w", err)
- }
-
- if err := s.addStatus(req, archive); err != nil {
- return fmt.Errorf("add status: %w", err)
- }
-
- anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
- status := s.statusRecorder.GetFullStatus()
- seedFromStatus(anonymizer, &status)
-
- if err := s.addConfig(req, anonymizer, archive); err != nil {
- log.Errorf("Failed to add config to debug bundle: %v", err)
- }
-
- if req.GetSystemInfo() {
- s.addSystemInfo(req, anonymizer, archive)
- }
-
- if err := s.addProf(req, anonymizer, archive); err != nil {
- log.Errorf("Failed to add goroutines rules to debug bundle: %v", err)
- }
-
- if err := s.addNetworkMap(req, anonymizer, archive); err != nil {
- return fmt.Errorf("add network map: %w", err)
- }
-
- if err := s.addStateFile(req, anonymizer, archive); err != nil {
- log.Errorf("Failed to add state file to debug bundle: %v", err)
- }
-
- if err := s.addCorruptedStateFiles(archive); err != nil {
- log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
- }
-
- if s.logFile != "console" {
- if err := s.addLogfile(req, anonymizer, archive); err != nil {
- return fmt.Errorf("add log file: %w", err)
- }
- }
-
- if err := archive.Close(); err != nil {
- return fmt.Errorf("close archive writer: %w", err)
- }
- return nil
-}
-
-func (s *Server) addSystemInfo(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) {
- if err := s.addRoutes(req, anonymizer, archive); err != nil {
- log.Errorf("Failed to add routes to debug bundle: %v", err)
- }
-
- if err := s.addInterfaces(req, anonymizer, archive); err != nil {
- log.Errorf("Failed to add interfaces to debug bundle: %v", err)
- }
-
- if err := s.addFirewallRules(req, anonymizer, archive); err != nil {
- log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
- }
-}
-
-func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
- if req.GetAnonymize() {
- readmeReader := strings.NewReader(readmeContent)
- if err := addFileToZip(archive, readmeReader, "README.txt"); err != nil {
- return fmt.Errorf("add README file to zip: %w", err)
- }
- }
- return nil
-}
-
-func (s *Server) addStatus(req *proto.DebugBundleRequest, archive *zip.Writer) error {
- if status := req.GetStatus(); status != "" {
- statusReader := strings.NewReader(status)
- if err := addFileToZip(archive, statusReader, "status.txt"); err != nil {
- return fmt.Errorf("add status file to zip: %w", err)
- }
- }
- return nil
-}
-
-func (s *Server) addConfig(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
- var configContent strings.Builder
- s.addCommonConfigFields(&configContent)
-
- if req.GetAnonymize() {
- if s.config.ManagementURL != nil {
- configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", anonymizer.AnonymizeURI(s.config.ManagementURL.String())))
- }
- if s.config.AdminURL != nil {
- configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", anonymizer.AnonymizeURI(s.config.AdminURL.String())))
- }
- configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", anonymizeNATExternalIPs(s.config.NATExternalIPs, anonymizer)))
- if s.config.CustomDNSAddress != "" {
- configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", anonymizer.AnonymizeString(s.config.CustomDNSAddress)))
- }
- } else {
- if s.config.ManagementURL != nil {
- configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", s.config.ManagementURL.String()))
- }
- if s.config.AdminURL != nil {
- configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", s.config.AdminURL.String()))
- }
- configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", s.config.NATExternalIPs))
- if s.config.CustomDNSAddress != "" {
- configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", s.config.CustomDNSAddress))
- }
- }
-
- // Add config content to zip file
- configReader := strings.NewReader(configContent.String())
- if err := addFileToZip(archive, configReader, "config.txt"); err != nil {
- return fmt.Errorf("add config file to zip: %w", err)
- }
-
- return nil
-}
-
-func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
- configContent.WriteString("NetBird Client Configuration:\n\n")
-
- // Add non-sensitive fields
- configContent.WriteString(fmt.Sprintf("WgIface: %s\n", s.config.WgIface))
- configContent.WriteString(fmt.Sprintf("WgPort: %d\n", s.config.WgPort))
- if s.config.NetworkMonitor != nil {
- configContent.WriteString(fmt.Sprintf("NetworkMonitor: %v\n", *s.config.NetworkMonitor))
- }
- configContent.WriteString(fmt.Sprintf("IFaceBlackList: %v\n", s.config.IFaceBlackList))
- configContent.WriteString(fmt.Sprintf("DisableIPv6Discovery: %v\n", s.config.DisableIPv6Discovery))
- configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", s.config.RosenpassEnabled))
- configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", s.config.RosenpassPermissive))
- if s.config.ServerSSHAllowed != nil {
- configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *s.config.ServerSSHAllowed))
- }
- configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", s.config.DisableAutoConnect))
- configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", s.config.DNSRouteInterval))
-
- configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", s.config.DisableClientRoutes))
- configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", s.config.DisableServerRoutes))
- configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", s.config.DisableDNS))
- configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", s.config.DisableFirewall))
-
- configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", s.config.BlockLANAccess))
-}
-
-func (s *Server) addProf(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
- runtime.SetBlockProfileRate(1)
- _ = runtime.SetMutexProfileFraction(1)
- defer runtime.SetBlockProfileRate(0)
- defer runtime.SetMutexProfileFraction(0)
-
- time.Sleep(5 * time.Second)
-
- for _, profile := range []string{"goroutine", "block", "mutex"} {
- var buff []byte
- myBuff := bytes.NewBuffer(buff)
- err := pprof.Lookup(profile).WriteTo(myBuff, 0)
- if err != nil {
- return fmt.Errorf("write %s profile: %w", profile, err)
- }
-
- if err := addFileToZip(archive, myBuff, profile+".prof"); err != nil {
- return fmt.Errorf("add %s file to zip: %w", profile, err)
- }
- }
- return nil
-}
-
-func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
- routes, err := systemops.GetRoutesFromTable()
- if err != nil {
- return fmt.Errorf("get routes: %w", err)
- }
-
- // TODO: get routes including nexthop
- routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
- routesReader := strings.NewReader(routesContent)
- if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
- return fmt.Errorf("add routes file to zip: %w", err)
- }
- return nil
-}
-
-func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
- interfaces, err := net.Interfaces()
- if err != nil {
- return fmt.Errorf("get interfaces: %w", err)
- }
-
- interfacesContent := formatInterfaces(interfaces, req.GetAnonymize(), anonymizer)
- interfacesReader := strings.NewReader(interfacesContent)
- if err := addFileToZip(archive, interfacesReader, "interfaces.txt"); err != nil {
- return fmt.Errorf("add interfaces file to zip: %w", err)
- }
-
- return nil
-}
-
-func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
networkMap, err := s.getLatestNetworkMap()
if err != nil {
- // Skip if network map is not available, but log it
- log.Debugf("skipping empty network map in debug bundle: %v", err)
- return nil
+ log.Warnf("failed to get latest network map: %v", err)
}
+ bundleGenerator := debug.NewBundleGenerator(
+ debug.GeneratorDependencies{
+ InternalConfig: s.config,
+ StatusRecorder: s.statusRecorder,
+ NetworkMap: networkMap,
+ LogFile: s.logFile,
+ },
+ debug.BundleConfig{
+ Anonymize: req.GetAnonymize(),
+ ClientStatus: req.GetStatus(),
+ IncludeSystemInfo: req.GetSystemInfo(),
+ },
+ )
- if req.GetAnonymize() {
- if err := anonymizeNetworkMap(networkMap, anonymizer); err != nil {
- return fmt.Errorf("anonymize network map: %w", err)
- }
- }
-
- options := protojson.MarshalOptions{
- EmitUnpopulated: true,
- UseProtoNames: true,
- Indent: " ",
- AllowPartial: true,
- }
-
- jsonBytes, err := options.Marshal(networkMap)
+ path, err := bundleGenerator.Generate()
if err != nil {
- return fmt.Errorf("generate json: %w", err)
+ return nil, fmt.Errorf("generate debug bundle: %w", err)
}
- if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "network_map.json"); err != nil {
- return fmt.Errorf("add network map to zip: %w", err)
- }
-
- return nil
-}
-
-func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
- path := statemanager.GetDefaultStatePath()
- if path == "" {
- return nil
- }
-
- data, err := os.ReadFile(path)
- if err != nil {
- if errors.Is(err, fs.ErrNotExist) {
- return nil
- }
- return fmt.Errorf("read state file: %w", err)
- }
-
- if req.GetAnonymize() {
- var rawStates map[string]json.RawMessage
- if err := json.Unmarshal(data, &rawStates); err != nil {
- return fmt.Errorf("unmarshal states: %w", err)
- }
-
- if err := anonymizeStateFile(&rawStates, anonymizer); err != nil {
- return fmt.Errorf("anonymize state file: %w", err)
- }
-
- bs, err := json.MarshalIndent(rawStates, "", " ")
- if err != nil {
- return fmt.Errorf("marshal states: %w", err)
- }
- data = bs
- }
-
- if err := addFileToZip(archive, bytes.NewReader(data), "state.json"); err != nil {
- return fmt.Errorf("add state file to zip: %w", err)
- }
-
- return nil
-}
-
-func (s *Server) addCorruptedStateFiles(archive *zip.Writer) error {
- pattern := statemanager.GetDefaultStatePath()
- if pattern == "" {
- return nil
- }
- pattern += "*.corrupted.*"
- matches, err := filepath.Glob(pattern)
- if err != nil {
- return fmt.Errorf("find corrupted state files: %w", err)
- }
-
- for _, match := range matches {
- data, err := os.ReadFile(match)
- if err != nil {
- log.Warnf("Failed to read corrupted state file %s: %v", match, err)
- continue
- }
-
- fileName := filepath.Base(match)
- if err := addFileToZip(archive, bytes.NewReader(data), "corrupted_states/"+fileName); err != nil {
- log.Warnf("Failed to add corrupted state file %s to zip: %v", fileName, err)
- continue
- }
-
- log.Debugf("Added corrupted state file to debug bundle: %s", fileName)
- }
-
- return nil
-}
-
-func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
- logDir := filepath.Dir(s.logFile)
-
- if err := s.addSingleLogfile(s.logFile, clientLogFile, req, anonymizer, archive); err != nil {
- return fmt.Errorf("add client log file to zip: %w", err)
- }
-
- stdErrLogPath := filepath.Join(logDir, errorLogFile)
- stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
- if runtime.GOOS == "darwin" {
- stdErrLogPath = darwinErrorLogPath
- stdoutLogPath = darwinStdoutLogPath
- }
-
- if err := s.addSingleLogfile(stdErrLogPath, errorLogFile, req, anonymizer, archive); err != nil {
- log.Warnf("Failed to add %s to zip: %v", errorLogFile, err)
- }
-
- if err := s.addSingleLogfile(stdoutLogPath, stdoutLogFile, req, anonymizer, archive); err != nil {
- log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err)
- }
-
- return nil
-}
-
-// addSingleLogfile adds a single log file to the archive
-func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
- logFile, err := os.Open(logPath)
- if err != nil {
- return fmt.Errorf("open log file %s: %w", targetName, err)
- }
- defer func() {
- if err := logFile.Close(); err != nil {
- log.Errorf("Failed to close log file %s: %v", targetName, err)
- }
- }()
-
- var logReader io.Reader
- if req.GetAnonymize() {
- var writer *io.PipeWriter
- logReader, writer = io.Pipe()
-
- go anonymizeLog(logFile, writer, anonymizer)
- } else {
- logReader = logFile
- }
-
- if err := addFileToZip(archive, logReader, targetName); err != nil {
- return fmt.Errorf("add %s to zip: %w", targetName, err)
- }
-
- return nil
-}
-
-// getLatestNetworkMap returns the latest network map from the engine if network map persistence is enabled
-func (s *Server) getLatestNetworkMap() (*mgmProto.NetworkMap, error) {
- if s.connectClient == nil {
- return nil, errors.New("connect client is not initialized")
- }
-
- engine := s.connectClient.Engine()
- if engine == nil {
- return nil, errors.New("engine is not initialized")
- }
-
- networkMap, err := engine.GetLatestNetworkMap()
- if err != nil {
- return nil, fmt.Errorf("get latest network map: %w", err)
- }
-
- if networkMap == nil {
- return nil, errors.New("network map is not available")
- }
-
- return networkMap, nil
+ return &proto.DebugBundleResponse{Path: path}, nil
}
// GetLogLevel gets the current logging level for the server.
@@ -612,439 +100,12 @@ func (s *Server) SetNetworkMapPersistence(_ context.Context, req *proto.SetNetwo
return &proto.SetNetworkMapPersistenceResponse{}, nil
}
-func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
- header := &zip.FileHeader{
- Name: filename,
- Method: zip.Deflate,
- Modified: time.Now(),
-
- CreatorVersion: 20, // Version 2.0
- ReaderVersion: 20, // Version 2.0
- Flags: 0x800, // UTF-8 filename
+// getLatestNetworkMap returns the latest network map from the engine if network map persistence is enabled
+func (s *Server) getLatestNetworkMap() (*mgmProto.NetworkMap, error) {
+ cClient := s.connectClient
+ if cClient == nil {
+ return nil, errors.New("connect client is not initialized")
}
- // If the reader is a file, we can get more accurate information
- if f, ok := reader.(*os.File); ok {
- if stat, err := f.Stat(); err != nil {
- log.Tracef("Failed to get file stat for %s: %v", filename, err)
- } else {
- header.Modified = stat.ModTime()
- }
- }
-
- writer, err := archive.CreateHeader(header)
- if err != nil {
- return fmt.Errorf("create zip file header: %w", err)
- }
-
- if _, err := io.Copy(writer, reader); err != nil {
- return fmt.Errorf("write file to zip: %w", err)
- }
-
- return nil
-}
-
-func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
- status.ManagementState.URL = a.AnonymizeURI(status.ManagementState.URL)
- status.SignalState.URL = a.AnonymizeURI(status.SignalState.URL)
-
- status.LocalPeerState.FQDN = a.AnonymizeDomain(status.LocalPeerState.FQDN)
-
- for _, peer := range status.Peers {
- a.AnonymizeDomain(peer.FQDN)
- for route := range peer.GetRoutes() {
- a.AnonymizeRoute(route)
- }
- }
-
- for route := range status.LocalPeerState.Routes {
- a.AnonymizeRoute(route)
- }
-
- for _, nsGroup := range status.NSGroupStates {
- for _, domain := range nsGroup.Domains {
- a.AnonymizeDomain(domain)
- }
- }
-
- for _, relay := range status.Relays {
- if relay.URI != "" {
- a.AnonymizeURI(relay.URI)
- }
- }
-}
-
-func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string {
- var ipv4Routes, ipv6Routes []netip.Prefix
-
- // Separate IPv4 and IPv6 routes
- for _, route := range routes {
- if route.Addr().Is4() {
- ipv4Routes = append(ipv4Routes, route)
- } else {
- ipv6Routes = append(ipv6Routes, route)
- }
- }
-
- // Sort IPv4 and IPv6 routes separately
- sort.Slice(ipv4Routes, func(i, j int) bool {
- return ipv4Routes[i].Bits() > ipv4Routes[j].Bits()
- })
- sort.Slice(ipv6Routes, func(i, j int) bool {
- return ipv6Routes[i].Bits() > ipv6Routes[j].Bits()
- })
-
- var builder strings.Builder
-
- // Format IPv4 routes
- builder.WriteString("IPv4 Routes:\n")
- for _, route := range ipv4Routes {
- formatRoute(&builder, route, anonymize, anonymizer)
- }
-
- // Format IPv6 routes
- builder.WriteString("\nIPv6 Routes:\n")
- for _, route := range ipv6Routes {
- formatRoute(&builder, route, anonymize, anonymizer)
- }
-
- return builder.String()
-}
-
-func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) {
- if anonymize {
- anonymizedIP := anonymizer.AnonymizeIP(route.Addr())
- builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits()))
- } else {
- builder.WriteString(fmt.Sprintf("%s\n", route))
- }
-}
-
-func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string {
- sort.Slice(interfaces, func(i, j int) bool {
- return interfaces[i].Name < interfaces[j].Name
- })
-
- var builder strings.Builder
- builder.WriteString("Network Interfaces:\n")
-
- for _, iface := range interfaces {
- builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name))
- builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index))
- builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU))
- builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags))
-
- addrs, err := iface.Addrs()
- if err != nil {
- builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err))
- } else {
- builder.WriteString(" Addresses:\n")
- for _, addr := range addrs {
- prefix, err := netip.ParsePrefix(addr.String())
- if err != nil {
- builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err))
- continue
- }
- ip := prefix.Addr()
- if anonymize {
- ip = anonymizer.AnonymizeIP(ip)
- }
- builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits()))
- }
- }
- }
-
- return builder.String()
-}
-
-func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
- defer func() {
- // always nil
- _ = writer.Close()
- }()
-
- scanner := bufio.NewScanner(reader)
- for scanner.Scan() {
- line := anonymizer.AnonymizeString(scanner.Text())
- if _, err := writer.Write([]byte(line + "\n")); err != nil {
- writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
- return
- }
- }
- if err := scanner.Err(); err != nil {
- writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
- return
- }
-}
-
-func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
- anonymizedIPs := make([]string, len(ips))
- for i, ip := range ips {
- parts := strings.SplitN(ip, "/", 2)
-
- ip1, err := netip.ParseAddr(parts[0])
- if err != nil {
- anonymizedIPs[i] = ip
- continue
- }
- ip1anon := anonymizer.AnonymizeIP(ip1)
-
- if len(parts) == 2 {
- ip2, err := netip.ParseAddr(parts[1])
- if err != nil {
- anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, parts[1])
- } else {
- ip2anon := anonymizer.AnonymizeIP(ip2)
- anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, ip2anon)
- }
- } else {
- anonymizedIPs[i] = ip1anon.String()
- }
- }
- return anonymizedIPs
-}
-
-func anonymizeNetworkMap(networkMap *mgmProto.NetworkMap, anonymizer *anonymize.Anonymizer) error {
- if networkMap.PeerConfig != nil {
- anonymizePeerConfig(networkMap.PeerConfig, anonymizer)
- }
-
- for _, peer := range networkMap.RemotePeers {
- anonymizeRemotePeer(peer, anonymizer)
- }
-
- for _, peer := range networkMap.OfflinePeers {
- anonymizeRemotePeer(peer, anonymizer)
- }
-
- for _, r := range networkMap.Routes {
- anonymizeRoute(r, anonymizer)
- }
-
- if networkMap.DNSConfig != nil {
- anonymizeDNSConfig(networkMap.DNSConfig, anonymizer)
- }
-
- for _, rule := range networkMap.FirewallRules {
- anonymizeFirewallRule(rule, anonymizer)
- }
-
- for _, rule := range networkMap.RoutesFirewallRules {
- anonymizeRouteFirewallRule(rule, anonymizer)
- }
-
- return nil
-}
-
-func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anonymizer) {
- if config == nil {
- return
- }
-
- if addr, err := netip.ParseAddr(config.Address); err == nil {
- config.Address = anonymizer.AnonymizeIP(addr).String()
- }
-
- if config.SshConfig != nil && len(config.SshConfig.SshPubKey) > 0 {
- config.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
- }
-
- config.Dns = anonymizer.AnonymizeString(config.Dns)
- config.Fqdn = anonymizer.AnonymizeDomain(config.Fqdn)
-}
-
-func anonymizeRemotePeer(peer *mgmProto.RemotePeerConfig, anonymizer *anonymize.Anonymizer) {
- if peer == nil {
- return
- }
-
- for i, ip := range peer.AllowedIps {
- // Try to parse as prefix first (CIDR)
- if prefix, err := netip.ParsePrefix(ip); err == nil {
- anonIP := anonymizer.AnonymizeIP(prefix.Addr())
- peer.AllowedIps[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
- } else if addr, err := netip.ParseAddr(ip); err == nil {
- peer.AllowedIps[i] = anonymizer.AnonymizeIP(addr).String()
- }
- }
-
- peer.Fqdn = anonymizer.AnonymizeDomain(peer.Fqdn)
-
- if peer.SshConfig != nil && len(peer.SshConfig.SshPubKey) > 0 {
- peer.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
- }
-}
-
-func anonymizeRoute(route *mgmProto.Route, anonymizer *anonymize.Anonymizer) {
- if route == nil {
- return
- }
-
- if prefix, err := netip.ParsePrefix(route.Network); err == nil {
- anonIP := anonymizer.AnonymizeIP(prefix.Addr())
- route.Network = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
- }
-
- for i, domain := range route.Domains {
- route.Domains[i] = anonymizer.AnonymizeDomain(domain)
- }
-
- route.NetID = anonymizer.AnonymizeString(route.NetID)
-}
-
-func anonymizeDNSConfig(config *mgmProto.DNSConfig, anonymizer *anonymize.Anonymizer) {
- if config == nil {
- return
- }
-
- anonymizeNameServerGroups(config.NameServerGroups, anonymizer)
- anonymizeCustomZones(config.CustomZones, anonymizer)
-}
-
-func anonymizeNameServerGroups(groups []*mgmProto.NameServerGroup, anonymizer *anonymize.Anonymizer) {
- for _, group := range groups {
- anonymizeServers(group.NameServers, anonymizer)
- anonymizeDomains(group.Domains, anonymizer)
- }
-}
-
-func anonymizeServers(servers []*mgmProto.NameServer, anonymizer *anonymize.Anonymizer) {
- for _, server := range servers {
- if addr, err := netip.ParseAddr(server.IP); err == nil {
- server.IP = anonymizer.AnonymizeIP(addr).String()
- }
- }
-}
-
-func anonymizeDomains(domains []string, anonymizer *anonymize.Anonymizer) {
- for i, domain := range domains {
- domains[i] = anonymizer.AnonymizeDomain(domain)
- }
-}
-
-func anonymizeCustomZones(zones []*mgmProto.CustomZone, anonymizer *anonymize.Anonymizer) {
- for _, zone := range zones {
- zone.Domain = anonymizer.AnonymizeDomain(zone.Domain)
- anonymizeRecords(zone.Records, anonymizer)
- }
-}
-
-func anonymizeRecords(records []*mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
- for _, record := range records {
- record.Name = anonymizer.AnonymizeDomain(record.Name)
- anonymizeRData(record, anonymizer)
- }
-}
-
-func anonymizeRData(record *mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
- switch record.Type {
- case 1, 28: // A or AAAA record
- if addr, err := netip.ParseAddr(record.RData); err == nil {
- record.RData = anonymizer.AnonymizeIP(addr).String()
- }
- default:
- record.RData = anonymizer.AnonymizeString(record.RData)
- }
-}
-
-func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.Anonymizer) {
- if rule == nil {
- return
- }
-
- if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
- rule.PeerIP = anonymizer.AnonymizeIP(addr).String()
- }
-}
-
-func anonymizeRouteFirewallRule(rule *mgmProto.RouteFirewallRule, anonymizer *anonymize.Anonymizer) {
- if rule == nil {
- return
- }
-
- for i, sourceRange := range rule.SourceRanges {
- if prefix, err := netip.ParsePrefix(sourceRange); err == nil {
- anonIP := anonymizer.AnonymizeIP(prefix.Addr())
- rule.SourceRanges[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
- }
- }
-
- if prefix, err := netip.ParsePrefix(rule.Destination); err == nil {
- anonIP := anonymizer.AnonymizeIP(prefix.Addr())
- rule.Destination = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
- }
-}
-
-func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error {
- for name, rawState := range *rawStates {
- if string(rawState) == "null" {
- continue
- }
-
- var state map[string]any
- if err := json.Unmarshal(rawState, &state); err != nil {
- return fmt.Errorf("unmarshal state %s: %w", name, err)
- }
-
- state = anonymizeValue(state, anonymizer).(map[string]any)
-
- bs, err := json.Marshal(state)
- if err != nil {
- return fmt.Errorf("marshal state %s: %w", name, err)
- }
-
- (*rawStates)[name] = bs
- }
-
- return nil
-}
-
-func anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any {
- switch v := value.(type) {
- case string:
- return anonymizeString(v, anonymizer)
- case map[string]any:
- return anonymizeMap(v, anonymizer)
- case []any:
- return anonymizeSlice(v, anonymizer)
- }
- return value
-}
-
-func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string {
- if prefix, err := netip.ParsePrefix(v); err == nil {
- anonIP := anonymizer.AnonymizeIP(prefix.Addr())
- return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
- }
- if ip, err := netip.ParseAddr(v); err == nil {
- return anonymizer.AnonymizeIP(ip).String()
- }
- return anonymizer.AnonymizeString(v)
-}
-
-func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any {
- result := make(map[string]any, len(v))
- for key, val := range v {
- newKey := anonymizeMapKey(key, anonymizer)
- result[newKey] = anonymizeValue(val, anonymizer)
- }
- return result
-}
-
-func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string {
- if prefix, err := netip.ParsePrefix(key); err == nil {
- anonIP := anonymizer.AnonymizeIP(prefix.Addr())
- return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
- }
- if ip, err := netip.ParseAddr(key); err == nil {
- return anonymizer.AnonymizeIP(ip).String()
- }
- return key
-}
-
-func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any {
- for i, val := range v {
- v[i] = anonymizeValue(val, anonymizer)
- }
- return v
+ return cClient.GetLatestNetworkMap()
}
diff --git a/client/server/debug_nonlinux.go b/client/server/debug_nonlinux.go
deleted file mode 100644
index c54ac9b6e..000000000
--- a/client/server/debug_nonlinux.go
+++ /dev/null
@@ -1,15 +0,0 @@
-//go:build !linux || android
-
-package server
-
-import (
- "archive/zip"
-
- "github.com/netbirdio/netbird/client/anonymize"
- "github.com/netbirdio/netbird/client/proto"
-)
-
-// collectFirewallRules returns nothing on non-linux systems
-func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
- return nil
-}
diff --git a/client/server/server_test.go b/client/server/server_test.go
index 8ee8294cf..f2dff76fd 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -8,10 +8,11 @@ import (
"time"
"github.com/golang/mock/gomock"
- "github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
+ "github.com/netbirdio/management-integrations/integrations"
+
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
@@ -200,10 +201,10 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
- permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
+ permissionsManagerMock := permissions.NewMockManager(ctrl)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
diff --git a/flow/proto/flow.pb.go b/flow/proto/flow.pb.go
index 9d82a6e5a..a1b668cdc 100644
--- a/flow/proto/flow.pb.go
+++ b/flow/proto/flow.pb.go
@@ -569,7 +569,7 @@ var file_flow_flow_proto_rawDesc = []byte{
0x0a, 0x0f, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x12, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61,
- 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xb2, 0x01, 0x0a, 0x09, 0x46, 0x6c, 0x6f,
+ 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xd4, 0x01, 0x0a, 0x09, 0x46, 0x6c, 0x6f,
0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f,
0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49,
0x64, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02,
@@ -580,67 +580,71 @@ var file_flow_flow_proto_rawDesc = []byte{
0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x31, 0x0a, 0x0b, 0x66, 0x6c,
0x6f, 0x77, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32,
0x10, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64,
- 0x73, 0x52, 0x0a, 0x66, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x22, 0x29, 0x0a,
- 0x0c, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a,
- 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52,
- 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x22, 0x9c, 0x04, 0x0a, 0x0a, 0x46, 0x6c, 0x6f,
- 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f,
- 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64,
- 0x12, 0x1e, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a,
- 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65,
- 0x12, 0x17, 0x0a, 0x07, 0x72, 0x75, 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28,
- 0x0c, 0x52, 0x06, 0x72, 0x75, 0x6c, 0x65, 0x49, 0x64, 0x12, 0x2d, 0x0a, 0x09, 0x64, 0x69, 0x72,
- 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0f, 0x2e, 0x66,
- 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x64,
- 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74,
- 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74,
- 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69,
- 0x70, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49,
- 0x70, 0x12, 0x17, 0x0a, 0x07, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x70, 0x18, 0x07, 0x20, 0x01,
- 0x28, 0x0c, 0x52, 0x06, 0x64, 0x65, 0x73, 0x74, 0x49, 0x70, 0x12, 0x2d, 0x0a, 0x09, 0x70, 0x6f,
- 0x72, 0x74, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e,
- 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52,
- 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x2d, 0x0a, 0x09, 0x69, 0x63, 0x6d,
- 0x70, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66,
- 0x6c, 0x6f, 0x77, 0x2e, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08,
- 0x69, 0x63, 0x6d, 0x70, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x78, 0x5f, 0x70,
- 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x72, 0x78,
- 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x78, 0x5f, 0x70, 0x61,
- 0x63, 0x6b, 0x65, 0x74, 0x73, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x74, 0x78, 0x50,
- 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74,
- 0x65, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65,
- 0x73, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20,
- 0x01, 0x28, 0x04, 0x52, 0x07, 0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x12,
- 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f,
- 0x69, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x10, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65,
- 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65,
- 0x73, 0x74, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0f,
- 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72,
- 0x63, 0x65, 0x49, 0x64, 0x42, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69,
- 0x6f, 0x6e, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49,
- 0x6e, 0x66, 0x6f, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f,
- 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65,
- 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x6f, 0x72,
- 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74, 0x50, 0x6f, 0x72,
- 0x74, 0x22, 0x44, 0x0a, 0x08, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1b, 0x0a,
- 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d,
- 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63,
- 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69,
- 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12,
- 0x10, 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
- 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x52, 0x54, 0x10,
- 0x01, 0x12, 0x0c, 0x0a, 0x08, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44, 0x10, 0x02, 0x12,
- 0x0d, 0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x03, 0x2a, 0x3b,
- 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x0a, 0x11, 0x44,
- 0x49, 0x52, 0x45, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e,
- 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x01, 0x12,
- 0x0a, 0x0a, 0x06, 0x45, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42, 0x0a, 0x0b, 0x46,
- 0x6c, 0x6f, 0x77, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x06, 0x45, 0x76,
- 0x65, 0x6e, 0x74, 0x73, 0x12, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77,
- 0x45, 0x76, 0x65, 0x6e, 0x74, 0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f,
- 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42,
- 0x0c, 0x5a, 0x0a, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70,
- 0x72, 0x6f, 0x74, 0x6f, 0x33,
+ 0x73, 0x52, 0x0a, 0x66, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x20, 0x0a,
+ 0x0b, 0x69, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x05, 0x20, 0x01,
+ 0x28, 0x08, 0x52, 0x0b, 0x69, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x22,
+ 0x4b, 0x0a, 0x0c, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12,
+ 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28,
+ 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x73,
+ 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52,
+ 0x0b, 0x69, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x22, 0x9c, 0x04, 0x0a,
+ 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66,
+ 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c,
+ 0x6f, 0x77, 0x49, 0x64, 0x12, 0x1e, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01,
+ 0x28, 0x0e, 0x32, 0x0a, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04,
+ 0x74, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x07, 0x72, 0x75, 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x18,
+ 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x72, 0x75, 0x6c, 0x65, 0x49, 0x64, 0x12, 0x2d, 0x0a,
+ 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e,
+ 0x32, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f,
+ 0x6e, 0x52, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x0a, 0x08,
+ 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08,
+ 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72,
+ 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x73, 0x6f, 0x75,
+ 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x17, 0x0a, 0x07, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x70,
+ 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x64, 0x65, 0x73, 0x74, 0x49, 0x70, 0x12, 0x2d,
+ 0x0a, 0x09, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x08, 0x20, 0x01, 0x28,
+ 0x0b, 0x32, 0x0e, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66,
+ 0x6f, 0x48, 0x00, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x2d, 0x0a,
+ 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b,
+ 0x32, 0x0e, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f,
+ 0x48, 0x00, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1d, 0x0a, 0x0a,
+ 0x72, 0x78, 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x04,
+ 0x52, 0x09, 0x72, 0x78, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x74,
+ 0x78, 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x04, 0x52,
+ 0x09, 0x74, 0x78, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78,
+ 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78,
+ 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65,
+ 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73,
+ 0x12, 0x2c, 0x0a, 0x12, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75,
+ 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x10, 0x73, 0x6f,
+ 0x75, 0x72, 0x63, 0x65, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x28,
+ 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f,
+ 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x52, 0x65,
+ 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x42, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e,
+ 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50,
+ 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63,
+ 0x65, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f,
+ 0x75, 0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74,
+ 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73,
+ 0x74, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x44, 0x0a, 0x08, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66,
+ 0x6f, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01,
+ 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b,
+ 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28,
+ 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54,
+ 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e,
+ 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54,
+ 0x41, 0x52, 0x54, 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e,
+ 0x44, 0x10, 0x02, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50,
+ 0x10, 0x03, 0x2a, 0x3b, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12,
+ 0x15, 0x0a, 0x11, 0x44, 0x49, 0x52, 0x45, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b,
+ 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53,
+ 0x53, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32,
+ 0x42, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x77, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33,
+ 0x0a, 0x06, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e,
+ 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77,
+ 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28,
+ 0x01, 0x30, 0x01, 0x42, 0x0c, 0x5a, 0x0a, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x70, 0x72, 0x6f, 0x74,
+ 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env
index 45dce8d88..4b1376921 100644
--- a/infrastructure_files/base.setup.env
+++ b/infrastructure_files/base.setup.env
@@ -58,6 +58,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
# PKCE authorization flow
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
+NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
# Dashboard
@@ -120,6 +121,7 @@ export NETBIRD_AUTH_DEVICE_AUTH_SCOPE
export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN
+export NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
export NETBIRD_AUTH_PKCE_AUDIENCE
export NETBIRD_DASH_AUTH_USE_AUDIENCE
export NETBIRD_DASH_AUTH_AUDIENCE
diff --git a/infrastructure_files/management.json.tmpl b/infrastructure_files/management.json.tmpl
index 5cbf2b4d3..aa1739c61 100644
--- a/infrastructure_files/management.json.tmpl
+++ b/infrastructure_files/management.json.tmpl
@@ -94,7 +94,8 @@
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
"Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
"RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
- "UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN
+ "UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
+ "DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
}
}
}
diff --git a/infrastructure_files/tests/setup.env b/infrastructure_files/tests/setup.env
index 5d774fbd1..2945e1c43 100644
--- a/infrastructure_files/tests/setup.env
+++ b/infrastructure_files/tests/setup.env
@@ -27,3 +27,4 @@ NETBIRD_STORE_CONFIG_ENGINE=$CI_NETBIRD_STORE_CONFIG_ENGINE
NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH
NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
NETBIRD_RELAY_PORT=33445
+NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
diff --git a/management/client/client_test.go b/management/client/client_test.go
index 6c30ff371..b22a79930 100644
--- a/management/client/client_test.go
+++ b/management/client/client_test.go
@@ -75,7 +75,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
- permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
@@ -88,6 +87,18 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
).
Return(&types.Settings{}, nil).
AnyTimes()
+ permissionsManagerMock := permissions.NewMockManager(ctrl)
+ permissionsManagerMock.
+ EXPECT().
+ ValidateUserPermissions(
+ gomock.Any(),
+ gomock.Any(),
+ gomock.Any(),
+ gomock.Any(),
+ gomock.Any(),
+ ).
+ Return(true, nil).
+ AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
diff --git a/management/client/rest/users.go b/management/client/rest/users.go
index 372bcee45..31ffad051 100644
--- a/management/client/rest/users.go
+++ b/management/client/rest/users.go
@@ -80,3 +80,16 @@ func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error {
return nil
}
+
+// Current gets the current user info
+// See more: https://docs.netbird.io/api/resources/users#retrieve-current-user
+func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) {
+ resp, err := a.c.newRequest(ctx, "GET", "/api/users/current", nil)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ ret, err := parseResponse[api.User](resp)
+ return &ret, err
+}
diff --git a/management/client/rest/users_test.go b/management/client/rest/users_test.go
index 2ff8a0327..f68c5f083 100644
--- a/management/client/rest/users_test.go
+++ b/management/client/rest/users_test.go
@@ -196,8 +196,42 @@ func TestUsers_ResendInvitation_Err(t *testing.T) {
})
}
+func TestUsers_Current_200(t *testing.T) {
+ withMockClient(func(c *rest.Client, mux *http.ServeMux) {
+ mux.HandleFunc("/api/users/current", func(w http.ResponseWriter, r *http.Request) {
+ retBytes, _ := json.Marshal(testUser)
+ _, err := w.Write(retBytes)
+ require.NoError(t, err)
+ })
+ ret, err := c.Users.Current(context.Background())
+ require.NoError(t, err)
+ assert.Equal(t, testUser, *ret)
+ })
+}
+
+func TestUsers_Current_Err(t *testing.T) {
+ withMockClient(func(c *rest.Client, mux *http.ServeMux) {
+ mux.HandleFunc("/api/users/current", func(w http.ResponseWriter, r *http.Request) {
+ retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
+ w.WriteHeader(400)
+ _, err := w.Write(retBytes)
+ require.NoError(t, err)
+ })
+ ret, err := c.Users.Current(context.Background())
+ assert.Error(t, err)
+ assert.Equal(t, "No", err.Error())
+ assert.Empty(t, ret)
+ })
+}
+
func TestUsers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
+ // rest client PAT is owner's
+ current, err := c.Users.Current(context.Background())
+ require.NoError(t, err)
+ assert.Equal(t, "a23efe53-63fb-11ec-90d6-0242ac120003", current.Id)
+ assert.Equal(t, "owner", current.Role)
+
user, err := c.Users.Create(context.Background(), api.UserCreateRequest{
AutoGroups: []string{},
Email: ptr("test@example.com"),
diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go
index aaba56e82..ced8da7fb 100644
--- a/management/proto/management.pb.go
+++ b/management/proto/management.pb.go
@@ -2140,6 +2140,8 @@ type ProviderConfig struct {
AuthorizationEndpoint string `protobuf:"bytes,9,opt,name=AuthorizationEndpoint,proto3" json:"AuthorizationEndpoint,omitempty"`
// RedirectURLs handles authorization code from IDP manager
RedirectURLs []string `protobuf:"bytes,10,rep,name=RedirectURLs,proto3" json:"RedirectURLs,omitempty"`
+ // DisablePromptLogin makes the PKCE flow to not prompt the user for login
+ DisablePromptLogin bool `protobuf:"varint,11,opt,name=DisablePromptLogin,proto3" json:"DisablePromptLogin,omitempty"`
}
func (x *ProviderConfig) Reset() {
@@ -2242,6 +2244,13 @@ func (x *ProviderConfig) GetRedirectURLs() []string {
return nil
}
+func (x *ProviderConfig) GetDisablePromptLogin() bool {
+ if x != nil {
+ return x.DisablePromptLogin
+ }
+ return false
+}
+
// Route represents a route.Route object
type Route struct {
state protoimpl.MessageState
@@ -3499,7 +3508,7 @@ var file_management_management_proto_rawDesc = []byte{
0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e,
- 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
+ 0x66, 0x69, 0x67, 0x22, 0x9a, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72,
@@ -3522,6 +3531,9 @@ var file_management_management_proto_rawDesc = []byte{
0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c,
0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03,
0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73,
+ 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70,
+ 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69,
+ 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65,
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74,
diff --git a/management/server/account.go b/management/server/account.go
index 0b52df2f0..d7f108dfe 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -6,11 +6,14 @@ import (
"fmt"
"math/rand"
"net"
+ "os"
"reflect"
"regexp"
"slices"
+ "strconv"
"strings"
"sync"
+ "sync/atomic"
"time"
cacheStore "github.com/eko/gocache/lib/v4/store"
@@ -30,6 +33,8 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/status"
@@ -92,6 +97,9 @@ type DefaultAccountManager struct {
metrics telemetry.AppMetrics
permissionsManager permissions.Manager
+
+ accountUpdateLocks sync.Map
+ updateAccountPeersBufferInterval atomic.Int64
}
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
@@ -186,6 +194,23 @@ func BuildManager(
settingsManager: settingsManager,
permissionsManager: permissionsManager,
}
+
+ var initialInterval int64
+ intervalStr := os.Getenv("PEER_UPDATE_INTERVAL_MS")
+ interval, err := strconv.Atoi(intervalStr)
+ if err != nil {
+ initialInterval = 1
+ } else {
+ initialInterval = int64(interval) * 10
+ go func() {
+ time.Sleep(30 * time.Second)
+ am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
+ log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
+ }()
+ }
+ am.updateAccountPeersBufferInterval.Store(initialInterval)
+ log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
+
accountsCounter, err := store.GetAccountsCounter(ctx)
if err != nil {
log.WithContext(ctx).Error(err)
@@ -258,7 +283,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err
}
- allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Settings, permissions.Write)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -508,7 +533,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return err
}
- allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Accounts, permissions.Write)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
if err != nil {
return fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -1021,13 +1046,12 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
// GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccount(ctx, accountID)
@@ -1223,7 +1247,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
- am.UpdateAccountPeers(ctx, userAuth.AccountId)
+ am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
}
}
@@ -1462,7 +1486,7 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
- am.UpdateAccountPeers(ctx, accountID)
+ am.BufferUpdateAccountPeers(ctx, accountID)
}
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
@@ -1515,19 +1539,13 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.St
}
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
-
- if !user.HasAdminPower() && !user.IsServiceUser {
- return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
- }
-
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
}
diff --git a/management/server/account/manager.go b/management/server/account/manager.go
index 807d05067..ea664d10e 100644
--- a/management/server/account/manager.go
+++ b/management/server/account/manager.go
@@ -59,15 +59,15 @@ type Manager interface {
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
- SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error
- SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
+ SaveGroup(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
+ SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group, create bool) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error)
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
- SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
+ SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error)
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@@ -93,7 +93,7 @@ type Manager interface {
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
- SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
+ SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager
@@ -114,4 +114,5 @@ type Manager interface {
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
+ GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error)
}
diff --git a/management/server/account_test.go b/management/server/account_test.go
index 49a7464e3..7f34cf845 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -1115,7 +1115,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Name: "GroupA",
Peers: []string{},
}
- if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
+ if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1131,7 +1131,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
},
},
- })
+ }, true)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1150,7 +1150,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}()
group.Peers = []string{peer1.ID, peer2.ID, peer3.ID}
- if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
+ if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1192,7 +1192,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
}
- if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
+ if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1223,7 +1223,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
},
},
- })
+ }, true)
if err != nil {
t.Errorf("delete default rule: %v", err)
return
@@ -1240,7 +1240,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
}
- if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
+ if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1256,7 +1256,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
},
},
- })
+ }, true)
if err != nil {
t.Errorf("save policy: %v", err)
return
@@ -1295,7 +1295,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
- })
+ }, true)
require.NoError(t, err, "failed to save group")
@@ -1315,7 +1315,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
},
},
- })
+ }, true)
if err != nil {
t.Errorf("save policy: %v", err)
return
@@ -2794,13 +2794,13 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
})
}
-//type TB interface {
+// type TB interface {
// Cleanup(func())
// Helper()
// TempDir() string
// Errorf(format string, args ...interface{})
// Fatalf(format string, args ...interface{})
-//}
+// }
func createManager(t testing.TB) (*DefaultAccountManager, error) {
t.Helper()
@@ -2816,8 +2816,6 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
return nil, err
}
- permissionsManagerMock := permissions.NewManagerMock()
-
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
@@ -2831,7 +2829,9 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
Return(false, nil).
AnyTimes()
- manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
+ permissionsManager := permissions.NewManager(store)
+
+ manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
if err != nil {
return nil, err
}
diff --git a/management/server/dns.go b/management/server/dns.go
index 8dcc59413..a3f32c2a9 100644
--- a/management/server/dns.go
+++ b/management/server/dns.go
@@ -10,6 +10,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -62,17 +64,12 @@ func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerG
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
@@ -84,17 +81,12 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
- }
-
- if !user.HasAdminPower() {
- return status.NewAdminPermissionError()
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
diff --git a/management/server/dns_test.go b/management/server/dns_test.go
index aeccc6187..36476b14c 100644
--- a/management/server/dns_test.go
+++ b/management/server/dns_test.go
@@ -211,14 +211,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
- permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
-
- return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
+ permissionsManager := permissions.NewManager(store)
+ return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
}
func createDNSStore(t *testing.T) (store.Store, error) {
@@ -505,7 +504,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
Name: "GroupB",
Peers: []string{},
},
- })
+ }, true)
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -565,7 +564,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
- })
+ }, true)
assert.NoError(t, err)
done := make(chan struct{})
diff --git a/management/server/event.go b/management/server/event.go
index 58c6c70fb..6342bfedb 100644
--- a/management/server/event.go
+++ b/management/server/event.go
@@ -9,6 +9,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -21,17 +23,12 @@ func isEnabled() bool {
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
- return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events")
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
@@ -56,7 +53,7 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI
filtered = append(filtered, event)
}
- err = am.fillEventsWithUserInfo(ctx, events, accountID, user)
+ err = am.fillEventsWithUserInfo(ctx, events, accountID, userID)
if err != nil {
return nil, err
}
@@ -89,8 +86,8 @@ type eventUserInfo struct {
accountId string
}
-func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, events []*activity.Event, accountId string, user *types.User) error {
- eventUserInfo, err := am.getEventsUserInfo(ctx, events, accountId, user)
+func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) error {
+ eventUserInfo, err := am.getEventsUserInfo(ctx, events, accountId, userId)
if err != nil {
return err
}
@@ -105,14 +102,14 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
return nil
}
-func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, user *types.User) (map[string]eventUserInfo, error) {
+func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId)
if err != nil {
return nil, err
}
// @note check whether using a external initiator user here is an issue
- userInfos, err := am.BuildUserInfosForAccount(ctx, accountId, user.Id, accountUsers)
+ userInfos, err := am.BuildUserInfosForAccount(ctx, accountId, userId, accountUsers)
if err != nil {
return nil, err
}
@@ -146,10 +143,10 @@ func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events [
return eventUserInfos, nil
}
- return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos, user)
+ return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos, userId)
}
-func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo, user *types.User) (map[string]eventUserInfo, error) {
+func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo, userId string) (map[string]eventUserInfo, error) {
externalAccountId := ""
fetched := make(map[string]struct{})
externalUsers := []*types.User{}
@@ -182,7 +179,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
return eventUserInfos, nil
}
- externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, user.Id, externalUsers)
+ externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, userId, externalUsers)
if err != nil {
return nil, err
}
diff --git a/management/server/group.go b/management/server/group.go
index 01ebb457c..0bd840798 100644
--- a/management/server/group.go
+++ b/management/server/group.go
@@ -12,6 +12,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -30,17 +32,13 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
if err != nil {
return err
}
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
- }
-
- if user.IsRegularUser() {
- return status.NewAdminPermissionError()
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
return nil
@@ -68,27 +66,26 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
}
// SaveGroup object of the peers
-func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
+func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, create bool) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup})
+ return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}, create)
}
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
-func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error {
+ operation := operations.Create
+ if !create {
+ operation = operations.Update
+ }
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operation)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
- }
-
- if user.IsRegularUser() {
- return status.NewAdminPermissionError()
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
var eventsToStore []func()
@@ -210,17 +207,12 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
- }
-
- if user.IsRegularUser() {
- return status.NewAdminPermissionError()
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
var allErrors error
diff --git a/management/server/group_test.go b/management/server/group_test.go
index dffaa80e3..4966f2b33 100644
--- a/management/server/group_test.go
+++ b/management/server/group_test.go
@@ -40,7 +40,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
}
for _, group := range account.Groups {
group.Issued = types.GroupIssuedIntegration
- err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
+ err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
if err != nil {
t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration)
}
@@ -48,7 +48,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = types.GroupIssuedJWT
- err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
+ err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
if err != nil {
t.Errorf("should allow to create %s groups", types.GroupIssuedJWT)
}
@@ -56,7 +56,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = types.GroupIssuedAPI
group.ID = ""
- err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
+ err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
if err == nil {
t.Errorf("should not create api group with the same name, %s", group.Name)
}
@@ -162,7 +162,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
}
}
- err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups)
+ err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups, true)
assert.NoError(t, err, "Failed to save test groups")
testCases := []struct {
@@ -382,13 +382,13 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
return nil, nil, err
}
- _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
- _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
- _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
- _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
- _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
- _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
- _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
+ _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute, true)
+ _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2, true)
+ _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups, true)
+ _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies, true)
+ _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys, true)
+ _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers, true)
+ _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration, true)
acc, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
@@ -426,7 +426,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Name: "GroupE",
Peers: []string{peer2.ID},
},
- })
+ }, true)
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -446,7 +446,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID},
- })
+ }, true)
assert.NoError(t, err)
select {
@@ -524,7 +524,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
},
},
- })
+ }, true)
assert.NoError(t, err)
// Saving a group linked to policy should update account peers and send peer update
@@ -539,7 +539,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
- })
+ }, true)
assert.NoError(t, err)
select {
@@ -608,7 +608,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
- })
+ }, true)
assert.NoError(t, err)
select {
@@ -649,7 +649,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
- })
+ }, true)
assert.NoError(t, err)
select {
@@ -676,7 +676,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
ID: "groupD",
Name: "GroupD",
Peers: []string{peer1.ID},
- })
+ }, true)
assert.NoError(t, err)
select {
@@ -723,7 +723,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
ID: "groupE",
Name: "GroupE",
Peers: []string{peer2.ID, peer3.ID},
- })
+ }, true)
assert.NoError(t, err)
select {
diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go
index 27698a085..df4b6c3d6 100644
--- a/management/server/groups/manager.go
+++ b/management/server/groups/manager.go
@@ -8,6 +8,8 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
@@ -39,7 +41,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, accou
}
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
if err != nil {
return nil, err
}
@@ -70,7 +72,7 @@ func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID str
}
func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Write)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return err
}
diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go
index dba5ab13b..a7ed639c3 100644
--- a/management/server/grpcserver.go
+++ b/management/server/grpcserver.go
@@ -828,6 +828,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope,
RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs,
UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken,
+ DisablePromptLogin: s.config.PKCEAuthorizationFlow.ProviderConfig.DisablePromptLogin,
},
}
diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml
index 82971541d..c699e9eef 100644
--- a/management/server/http/api/openapi.yml
+++ b/management/server/http/api/openapi.yml
@@ -2397,6 +2397,29 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
+ /api/users/current:
+ get:
+ summary: Retrieve current user
+ description: Get information about the current user
+ tags: [ Users ]
+ security:
+ - BearerAuth: [ ]
+ - TokenAuth: [ ]
+ responses:
+ '200':
+ description: A User object
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/User'
+ '400':
+ "$ref": "#/components/responses/bad_request"
+ '401':
+ "$ref": "#/components/responses/requires_authentication"
+ '403':
+ "$ref": "#/components/responses/forbidden"
+ '500':
+ "$ref": "#/components/responses/internal_error"
/api/peers:
get:
summary: List all Peers
diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go
index 751311333..9bdb3e4ac 100644
--- a/management/server/http/api/types.gen.go
+++ b/management/server/http/api/types.gen.go
@@ -230,7 +230,7 @@ type Account struct {
// AccountExtraSettings defines model for AccountExtraSettings.
type AccountExtraSettings struct {
- // NetworkTrafficLogsEnabled Enables or disables network traffic logs. If enabled, all network traffic logs from peers will be stored.
+ // NetworkTrafficLogsEnabled Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored.
NetworkTrafficLogsEnabled bool `json:"network_traffic_logs_enabled"`
// NetworkTrafficPacketCounterEnabled Enables or disables network traffic packet counter. If enabled, network packets and their size will be counted and reported. (This can have an slight impact on performance)
diff --git a/management/server/http/handler.go b/management/server/http/handler.go
index e4cc8585a..483bb989a 100644
--- a/management/server/http/handler.go
+++ b/management/server/http/handler.go
@@ -66,15 +66,13 @@ func NewAPIHandler(
corsMiddleware := cors.AllowAll()
- acMiddleware := middleware.NewAccessControl(accountManager.GetUserFromUserAuth)
-
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()
prefix := apiPrefix
router := rootRouter.PathPrefix(prefix).Subrouter()
- router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
+ router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
@@ -85,6 +83,8 @@ func NewAPIHandler(
users.AddEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
+ policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
+ policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router)
groups.AddEndpoints(accountManager, router)
routes.AddEndpoints(accountManager, router)
dns.AddEndpoints(accountManager, router)
diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go
index 667095018..3ae833dc0 100644
--- a/management/server/http/handlers/groups/groups_handler.go
+++ b/management/server/http/handlers/groups/groups_handler.go
@@ -143,7 +143,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
IntegrationReference: existingGroup.IntegrationReference,
}
- if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil {
+ if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, false); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
util.WriteError(r.Context(), err, w)
return
@@ -203,7 +203,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
Issued: types.GroupIssuedAPI,
}
- err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group)
+ err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, true)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go
index f4ac34e53..2caa2f5bf 100644
--- a/management/server/http/handlers/groups/groups_handler_test.go
+++ b/management/server/http/handlers/groups/groups_handler_test.go
@@ -35,7 +35,7 @@ var TestPeers = map[string]*nbpeer.Peer{
func initGroupTestData(initGroups ...*types.Group) *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
- SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group) error {
+ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error {
if !strings.HasPrefix(group.ID, "id-") {
group.ID = "id-was-set"
}
diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go
index ae7255e5f..fa78836d8 100644
--- a/management/server/http/handlers/peers/peers_handler.go
+++ b/management/server/http/handlers/peers/peers_handler.go
@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
+ "github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -244,7 +245,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
- account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
+ account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go
index fbdc324d6..b7b53f53f 100644
--- a/management/server/http/handlers/policies/geolocation_handler_test.go
+++ b/management/server/http/handlers/policies/geolocation_handler_test.go
@@ -10,6 +10,7 @@ import (
"path/filepath"
"testing"
+ "github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
@@ -17,6 +18,9 @@ import (
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
+ "github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
@@ -41,6 +45,14 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler {
assert.NoError(t, err)
t.Cleanup(func() { _ = geo.Stop() })
+ ctrl := gomock.NewController(t)
+ permissionsManagerMock := permissions.NewMockManager(ctrl)
+ permissionsManagerMock.
+ EXPECT().
+ ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), modules.Policies, operations.Read).
+ Return(true, nil).
+ AnyTimes()
+
return &geolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
@@ -48,6 +60,7 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler {
},
},
geolocationManager: geo,
+ permissionsManager: permissionsManagerMock,
}
}
diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go
index fb19887dc..84c8ea0aa 100644
--- a/management/server/http/handlers/policies/geolocations_handler.go
+++ b/management/server/http/handlers/policies/geolocations_handler.go
@@ -11,6 +11,9 @@ import (
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
+ "github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
)
@@ -22,19 +25,21 @@ var (
type geolocationsHandler struct {
accountManager account.Manager
geolocationManager geolocation.Geolocation
+ permissionsManager permissions.Manager
}
-func addLocationsEndpoint(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
- locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager)
+func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) {
+ locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager)
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
}
// newGeolocationsHandlerHandler creates a new Geolocations handler
-func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationManager geolocation.Geolocation) *geolocationsHandler {
+func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationManager geolocation.Geolocation, permissionsManager permissions.Manager) *geolocationsHandler {
return &geolocationsHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,
+ permissionsManager: permissionsManager,
}
}
@@ -98,20 +103,22 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
}
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
- userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
+ ctx := r.Context()
+
+ userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
return err
}
- _, userID := userAuth.AccountId, userAuth.UserId
+ accountID, userID := userAuth.AccountId, userAuth.UserId
- user, err := l.accountManager.GetUserByID(r.Context(), userID)
+ allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
- if !user.HasAdminPower() {
- return status.Errorf(status.PermissionDenied, "user is not allowed to perform this action")
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
return nil
}
diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go
index 01a09842a..9ff7ea0ea 100644
--- a/management/server/http/handlers/policies/policies_handler.go
+++ b/management/server/http/handlers/policies/policies_handler.go
@@ -28,7 +28,6 @@ func AddEndpoints(accountManager account.Manager, locationManager geolocation.Ge
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
- addPostureCheckEndpoint(accountManager, locationManager, router)
}
// newHandler creates a new policies handler
@@ -96,7 +95,7 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
return
}
- h.savePolicy(w, r, accountID, userID, policyID)
+ h.savePolicy(w, r, accountID, userID, policyID, false)
}
// createPolicy handles policy creation request
@@ -109,11 +108,11 @@ func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
accountID, userID := userAuth.AccountId, userAuth.UserId
- h.savePolicy(w, r, accountID, userID, "")
+ h.savePolicy(w, r, accountID, userID, "", true)
}
// savePolicy handles policy creation and update
-func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
+func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string, create bool) {
var req api.PutApiPoliciesPolicyIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -280,7 +279,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
policy.SourcePostureChecks = *req.SourcePostureChecks
}
- policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
+ policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy, create)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go
index 6450295eb..6f3dbc792 100644
--- a/management/server/http/handlers/policies/policies_handler_test.go
+++ b/management/server/http/handlers/policies/policies_handler_test.go
@@ -34,7 +34,7 @@ func initPoliciesTestData(policies ...*types.Policy) *handler {
}
return policy, nil
},
- SavePolicyFunc: func(_ context.Context, _, _ string, policy *types.Policy) (*types.Policy, error) {
+ SavePolicyFunc: func(_ context.Context, _, _ string, policy *types.Policy, create bool) (*types.Policy, error) {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go
index b99649dbc..2925f96ef 100644
--- a/management/server/http/handlers/policies/posture_checks_handler.go
+++ b/management/server/http/handlers/policies/posture_checks_handler.go
@@ -21,14 +21,13 @@ type postureChecksHandler struct {
geolocationManager geolocation.Geolocation
}
-func addPostureCheckEndpoint(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
+func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
- addLocationsEndpoint(accountManager, locationManager, router)
}
// newPostureChecksHandler creates a new PostureChecks handler
@@ -85,7 +84,7 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http
return
}
- p.savePostureChecks(w, r, accountID, userID, postureChecksID)
+ p.savePostureChecks(w, r, accountID, userID, postureChecksID, false)
}
// createPostureCheck handles posture check creation request
@@ -98,7 +97,7 @@ func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http
accountID, userID := userAuth.AccountId, userAuth.UserId
- p.savePostureChecks(w, r, accountID, userID, "")
+ p.savePostureChecks(w, r, accountID, userID, "", true)
}
// getPostureCheck handles a posture check Get request identified by ID
@@ -151,7 +150,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http
}
// savePostureChecks handles posture checks create and update
-func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
+func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string, create bool) {
var (
err error
req api.PostureCheckUpdate
@@ -176,7 +175,7 @@ func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
return
}
- postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
+ postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks, create)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go
index e3844caa2..e875b3738 100644
--- a/management/server/http/handlers/policies/posture_checks_handler_test.go
+++ b/management/server/http/handlers/policies/posture_checks_handler_test.go
@@ -40,7 +40,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
}
return p, nil
},
- SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
+ SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go
index 19f56c464..c69c6b944 100644
--- a/management/server/http/handlers/users/users_handler.go
+++ b/management/server/http/handlers/users/users_handler.go
@@ -25,6 +25,7 @@ type handler struct {
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
userHandler := newHandler(accountManager)
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
+ router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
@@ -259,6 +260,29 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
+func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
+ return
+ }
+ ctx := r.Context()
+ userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
+ user, err := h.accountManager.GetCurrentUserInfo(ctx, accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ util.WriteJSONObject(r.Context(), w, toUserResponse(user, userID))
+}
+
func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
autoGroups := user.AutoGroups
if autoGroups == nil {
diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go
index a6a904a4c..604954819 100644
--- a/management/server/http/handlers/users/users_handler_test.go
+++ b/management/server/http/handlers/users/users_handler_test.go
@@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
@@ -123,6 +124,64 @@ func initUsersTestData() *handler {
return nil
},
+ GetCurrentUserInfoFunc: func(ctx context.Context, accountID, userID string) (*types.UserInfo, error) {
+ switch userID {
+ case "not-found":
+ return nil, status.NewUserNotFoundError("not-found")
+ case "not-of-account":
+ return nil, status.NewUserNotPartOfAccountError()
+ case "blocked-user":
+ return nil, status.NewUserBlockedError()
+ case "service-user":
+ return nil, status.NewPermissionDeniedError()
+ case "owner":
+ return &types.UserInfo{
+ ID: "owner",
+ Name: "",
+ Role: "owner",
+ Status: "active",
+ IsServiceUser: false,
+ IsBlocked: false,
+ NonDeletable: false,
+ Issued: "api",
+ Permissions: types.UserPermissions{
+ DashboardView: "full",
+ },
+ }, nil
+ case "regular-user":
+ return &types.UserInfo{
+ ID: "regular-user",
+ Name: "",
+ Role: "user",
+ Status: "active",
+ IsServiceUser: false,
+ IsBlocked: false,
+ NonDeletable: false,
+ Issued: "api",
+ Permissions: types.UserPermissions{
+ DashboardView: "limited",
+ },
+ }, nil
+
+ case "admin-user":
+ return &types.UserInfo{
+ ID: "admin-user",
+ Name: "",
+ Role: "admin",
+ Status: "active",
+ IsServiceUser: false,
+ IsBlocked: false,
+ NonDeletable: false,
+ LastLogin: time.Time{},
+ Issued: "api",
+ Permissions: types.UserPermissions{
+ DashboardView: "full",
+ },
+ }, nil
+ }
+
+ return nil, fmt.Errorf("user id %s not handled", userID)
+ },
},
}
}
@@ -481,3 +540,73 @@ func TestDeleteUser(t *testing.T) {
})
}
}
+
+func TestCurrentUser(t *testing.T) {
+ tt := []struct {
+ name string
+ expectedStatus int
+ requestAuth nbcontext.UserAuth
+ }{
+ {
+ name: "without auth",
+ expectedStatus: http.StatusInternalServerError,
+ },
+ {
+ name: "user not found",
+ requestAuth: nbcontext.UserAuth{UserId: "not-found"},
+ expectedStatus: http.StatusNotFound,
+ },
+ {
+ name: "not of account",
+ requestAuth: nbcontext.UserAuth{UserId: "not-of-account"},
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "blocked user",
+ requestAuth: nbcontext.UserAuth{UserId: "blocked-user"},
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "service user",
+ requestAuth: nbcontext.UserAuth{UserId: "service-user"},
+ expectedStatus: http.StatusForbidden,
+ },
+ {
+ name: "owner",
+ requestAuth: nbcontext.UserAuth{UserId: "owner"},
+ expectedStatus: http.StatusOK,
+ },
+ {
+ name: "regular user",
+ requestAuth: nbcontext.UserAuth{UserId: "regular-user"},
+ expectedStatus: http.StatusOK,
+ },
+ {
+ name: "admin user",
+ requestAuth: nbcontext.UserAuth{UserId: "admin-user"},
+ expectedStatus: http.StatusOK,
+ },
+ }
+
+ userHandler := initUsersTestData()
+ for _, tc := range tt {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/api/users/current", nil)
+ if tc.requestAuth.UserId != "" {
+ req = nbcontext.SetUserAuthInRequest(req, tc.requestAuth)
+ }
+
+ rr := httptest.NewRecorder()
+
+ userHandler.getCurrentUser(rr, req)
+
+ res := rr.Result()
+ defer res.Body.Close()
+
+ if status := rr.Code; status != tc.expectedStatus {
+ t.Fatalf("handler returned wrong status code: got %v want %v",
+ status, tc.expectedStatus)
+ }
+ })
+ }
+}
diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go
deleted file mode 100644
index 4ed90f47b..000000000
--- a/management/server/http/middleware/access_control.go
+++ /dev/null
@@ -1,77 +0,0 @@
-package middleware
-
-import (
- "context"
- "net/http"
- "regexp"
-
- log "github.com/sirupsen/logrus"
-
- nbcontext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
- "github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/status"
- "github.com/netbirdio/netbird/management/server/types"
-)
-
-// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
-type GetUser func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
-
-// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
-type AccessControl struct {
- getUser GetUser
-}
-
-// NewAccessControl instance constructor
-func NewAccessControl(getUser GetUser) *AccessControl {
- return &AccessControl{
- getUser: getUser,
- }
-}
-
-var tokenPathRegexp = regexp.MustCompile(`^.*/api/users/.*/tokens.*$`)
-
-// Handler method of the middleware which forbids all modify requests for non admin users
-func (a *AccessControl) Handler(h http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
- if bypass.ShouldBypass(r.URL.Path, h, w, r) {
- return
- }
-
- userAuth, err := nbcontext.GetUserAuthFromRequest(r)
- if err != nil {
- log.WithContext(r.Context()).Errorf("failed to get user auth from request: %s", err)
- util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
- }
-
- user, err := a.getUser(r.Context(), userAuth)
- if err != nil {
- log.WithContext(r.Context()).Errorf("failed to get user: %s", err)
- util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
- return
- }
-
- if user.IsBlocked() {
- util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
- return
- }
-
- if !user.HasAdminPower() {
- switch r.Method {
- case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
-
- if tokenPathRegexp.MatchString(r.URL.Path) {
- log.WithContext(r.Context()).Debugf("valid Path")
- h.ServeHTTP(w, r)
- return
- }
-
- util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
- return
- }
- }
-
- h.ServeHTTP(w, r)
- })
-}
diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go
index 31ea06460..12e68e983 100644
--- a/management/server/http/testing/testing_tools/tools.go
+++ b/management/server/http/testing/testing_tools/tools.go
@@ -15,7 +15,6 @@ import (
"time"
"github.com/golang-jwt/jwt"
-
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -123,9 +122,9 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
validatorMock := server.MocIntegratedValidator{}
proxyController := integrations.NewController(store)
userManager := users.NewManager(store)
- permissionsManagerMock := permissions.NewManagerMock()
- settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManagerMock)
- am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManagerMock)
+ permissionsManager := permissions.NewManager(store)
+ settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
+ am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
@@ -143,9 +142,9 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
- peersManager := peers.NewManager(store, permissionsManagerMock)
+ peersManager := peers.NewManager(store, permissionsManager)
- apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManagerMock, peersManager, settingsManager)
+ apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index c87fe05ce..b85a43da4 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -432,8 +432,6 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
- permissionsManagerMock := permissions.NewManagerMock()
-
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
@@ -443,8 +441,10 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
AnyTimes().
Return(&types.Settings{}, nil)
+ permissionsManager := permissions.NewManager(store)
+
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
- eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
+ eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
if err != nil {
cleanup()
diff --git a/management/server/management_test.go b/management/server/management_test.go
index dd987c005..a4f9a5e38 100644
--- a/management/server/management_test.go
+++ b/management/server/management_test.go
@@ -195,7 +195,7 @@ func startServer(
Return(&types.Settings{}, nil).
AnyTimes()
- permissionsManagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(str)
accountManager, err := server.BuildManager(
context.Background(),
str,
@@ -210,7 +210,7 @@ func startServer(
metrics,
port_forwarding.NewControllerMock(),
settingsMockManager,
- permissionsManagerMock,
+ permissionsManager,
)
if err != nil {
t.Fatalf("failed creating an account manager: %v", err)
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index 008a7059f..870fe3219 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -44,8 +44,8 @@ type MockAccountManager struct {
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
- SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group) error
- SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group) error
+ SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
+ SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
@@ -53,7 +53,7 @@ type MockAccountManager struct {
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error)
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
- SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
+ SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error)
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
@@ -97,7 +97,7 @@ type MockAccountManager struct {
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() account.ExternalCacheManager
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
- SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
+ SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
@@ -115,6 +115,7 @@ type MockAccountManager struct {
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
+ GetCurrentUserInfoFunc func(ctx context.Context, accountID, userID string) (*types.UserInfo, error)
}
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
@@ -322,17 +323,17 @@ func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, gro
}
// SaveGroup mock implementation of SaveGroup from server.AccountManager interface
-func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
+func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *types.Group, create bool) error {
if am.SaveGroupFunc != nil {
- return am.SaveGroupFunc(ctx, accountID, userID, group)
+ return am.SaveGroupFunc(ctx, accountID, userID, group, create)
}
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
}
// SaveGroups mock implementation of SaveGroups from server.AccountManager interface
-func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
+func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error {
if am.SaveGroupsFunc != nil {
- return am.SaveGroupsFunc(ctx, accountID, userID, groups)
+ return am.SaveGroupsFunc(ctx, accountID, userID, groups, create)
}
return status.Errorf(codes.Unimplemented, "method SaveGroups is not implemented")
}
@@ -386,9 +387,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
}
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
-func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) {
+func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) {
if am.SavePolicyFunc != nil {
- return am.SavePolicyFunc(ctx, accountID, userID, policy)
+ return am.SavePolicyFunc(ctx, accountID, userID, policy, create)
}
return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
}
@@ -722,9 +723,9 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
}
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
-func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
+func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) {
if am.SavePostureChecksFunc != nil {
- return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
+ return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks, create)
}
return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
}
@@ -871,3 +872,10 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string
}
return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented")
}
+
+func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) {
+ if am.GetCurrentUserInfoFunc != nil {
+ return am.GetCurrentUserInfoFunc(ctx, accountID, userID)
+ }
+ return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
+}
diff --git a/management/server/nameserver.go b/management/server/nameserver.go
index b1cf2bc72..797d7c11c 100644
--- a/management/server/nameserver.go
+++ b/management/server/nameserver.go
@@ -11,6 +11,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -20,17 +22,12 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID)
@@ -41,13 +38,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Create)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
newNSGroup := &nbdns.NameServerGroup{
@@ -103,13 +99,12 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
}
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Update)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
@@ -154,13 +149,12 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Delete)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
var nsGroup *nbdns.NameServerGroup
@@ -198,17 +192,12 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go
index 13039ae63..1ba790797 100644
--- a/management/server/nameserver_test.go
+++ b/management/server/nameserver_test.go
@@ -775,12 +775,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
- permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
-
- return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
+ permissionsManager := permissions.NewManager(store)
+ return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
}
func createNSStore(t *testing.T) (store.Store, error) {
@@ -966,7 +965,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
},
- })
+ }, true)
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go
index 609b68918..1c46e9281 100644
--- a/management/server/networks/manager.go
+++ b/management/server/networks/manager.go
@@ -12,6 +12,8 @@ import (
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
)
@@ -46,7 +48,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, resou
}
func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -58,7 +60,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
}
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, permissions.Networks, permissions.Write)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -82,7 +84,7 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
}
func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -94,7 +96,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
}
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, permissions.Networks, permissions.Write)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -116,7 +118,7 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
}
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
diff --git a/management/server/networks/manager_test.go b/management/server/networks/manager_test.go
index edd830c25..bf196fcb3 100644
--- a/management/server/networks/manager_test.go
+++ b/management/server/networks/manager_test.go
@@ -18,7 +18,7 @@ import (
func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
if err != nil {
@@ -26,7 +26,7 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
@@ -41,7 +41,7 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
if err != nil {
@@ -49,7 +49,7 @@ func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
@@ -63,7 +63,7 @@ func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {
func Test_GetNetworkReturnsNetwork(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
@@ -72,7 +72,7 @@ func Test_GetNetworkReturnsNetwork(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
@@ -86,7 +86,7 @@ func Test_GetNetworkReturnsNetwork(t *testing.T) {
func Test_GetNetworkReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
@@ -95,7 +95,7 @@ func Test_GetNetworkReturnsPermissionDenied(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
@@ -108,7 +108,7 @@ func Test_GetNetworkReturnsPermissionDenied(t *testing.T) {
func Test_CreateNetworkSuccessfully(t *testing.T) {
ctx := context.Background()
- userID := "allowedUser"
+ userID := "testAdminId"
network := &types.Network{
AccountID: "testAccountId",
Name: "new-network",
@@ -120,7 +120,7 @@ func Test_CreateNetworkSuccessfully(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
@@ -133,7 +133,7 @@ func Test_CreateNetworkSuccessfully(t *testing.T) {
func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
- userID := "invalidUser"
+ userID := "testUserId"
network := &types.Network{
AccountID: "testAccountId",
Name: "new-network",
@@ -145,7 +145,7 @@ func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
@@ -159,7 +159,7 @@ func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) {
func Test_DeleteNetworkSuccessfully(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
@@ -168,7 +168,7 @@ func Test_DeleteNetworkSuccessfully(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
@@ -181,7 +181,7 @@ func Test_DeleteNetworkSuccessfully(t *testing.T) {
func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
@@ -190,7 +190,7 @@ func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
@@ -202,7 +202,7 @@ func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) {
func Test_UpdateNetworkSuccessfully(t *testing.T) {
ctx := context.Background()
- userID := "allowedUser"
+ userID := "testAdminId"
network := &types.Network{
AccountID: "testAccountId",
ID: "testNetworkId",
@@ -215,7 +215,7 @@ func Test_UpdateNetworkSuccessfully(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
@@ -228,7 +228,7 @@ func Test_UpdateNetworkSuccessfully(t *testing.T) {
func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
- userID := "invalidUser"
+ userID := "testUserId"
network := &types.Network{
AccountID: "testAccountId",
ID: "testNetworkId",
@@ -242,7 +242,7 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go
index acaacbfb9..21d1e54de 100644
--- a/management/server/networks/resources/manager.go
+++ b/management/server/networks/resources/manager.go
@@ -10,6 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
nbtypes "github.com/netbirdio/netbird/management/server/types"
@@ -47,7 +49,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, group
}
func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -59,7 +61,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
}
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -71,7 +73,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
}
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -93,7 +95,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
}
func (m *managerImpl) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, permissions.Networks, permissions.Write)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -164,7 +166,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
}
func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -185,7 +187,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
}
func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, permissions.Networks, permissions.Write)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -305,7 +307,7 @@ func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction stor
}
func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go
index 993cd65df..3a91b4af8 100644
--- a/management/server/networks/resources/manager_test.go
+++ b/management/server/networks/resources/manager_test.go
@@ -17,7 +17,7 @@ import (
func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
@@ -25,7 +25,7 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -38,7 +38,7 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
networkID := "testNetworkId"
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
@@ -46,7 +46,7 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -59,14 +59,14 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -79,14 +79,14 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -100,7 +100,7 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
resourceID := "testResourceId"
@@ -109,7 +109,7 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -122,7 +122,7 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
networkID := "testNetworkId"
resourceID := "testResourceId"
@@ -131,7 +131,7 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -144,7 +144,7 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
func Test_CreateResourceSuccessfully(t *testing.T) {
ctx := context.Background()
- userID := "allowedUser"
+ userID := "testAdminId"
resource := &types.NetworkResource{
AccountID: "testAccountId",
NetworkID: "testNetworkId",
@@ -158,7 +158,7 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -170,7 +170,7 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
- userID := "invalidUser"
+ userID := "testUserId"
resource := &types.NetworkResource{
AccountID: "testAccountId",
NetworkID: "testNetworkId",
@@ -184,7 +184,7 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -197,7 +197,7 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
ctx := context.Background()
- userID := "allowedUser"
+ userID := "testAdminId"
resource := &types.NetworkResource{
AccountID: "testAccountId",
NetworkID: "testNetworkId",
@@ -211,7 +211,7 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -223,7 +223,7 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
func Test_CreateResourceFailsWithUsedName(t *testing.T) {
ctx := context.Background()
- userID := "allowedUser"
+ userID := "testAdminId"
resource := &types.NetworkResource{
AccountID: "testAccountId",
NetworkID: "testNetworkId",
@@ -237,7 +237,7 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -250,7 +250,7 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
func Test_UpdateResourceSuccessfully(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
resourceID := "testResourceId"
resource := &types.NetworkResource{
@@ -267,7 +267,7 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -283,7 +283,7 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
resourceID := "otherResourceId"
resource := &types.NetworkResource{
@@ -299,7 +299,7 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -312,7 +312,7 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
resourceID := "testResourceId"
resource := &types.NetworkResource{
@@ -329,7 +329,7 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -342,7 +342,7 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
networkID := "testNetworkId"
resourceID := "testResourceId"
resource := &types.NetworkResource{
@@ -358,7 +358,7 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -371,7 +371,7 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
func Test_DeleteResourceSuccessfully(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
resourceID := "testResourceId"
@@ -380,7 +380,7 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
@@ -392,7 +392,7 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
networkID := "testNetworkId"
resourceID := "testResourceId"
@@ -401,7 +401,7 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go
index 595fffd97..7b488b361 100644
--- a/management/server/networks/routers/manager.go
+++ b/management/server/networks/routers/manager.go
@@ -12,6 +12,8 @@ import (
"github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
)
@@ -44,7 +46,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, accou
}
func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -56,7 +58,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
}
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -78,7 +80,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
}
func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, permissions.Networks, permissions.Write)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -126,7 +128,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
}
func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -147,7 +149,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
}
func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, permissions.Networks, permissions.Write)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -193,7 +195,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
}
func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go
index 47f5ad7e3..541643222 100644
--- a/management/server/networks/routers/manager_test.go
+++ b/management/server/networks/routers/manager_test.go
@@ -16,7 +16,7 @@ import (
func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
@@ -24,7 +24,7 @@ func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
@@ -37,7 +37,7 @@ func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {
func Test_GetAllRoutersInNetworkReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
@@ -45,7 +45,7 @@ func Test_GetAllRoutersInNetworkReturnsPermissionDenied(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
@@ -58,7 +58,7 @@ func Test_GetAllRoutersInNetworkReturnsPermissionDenied(t *testing.T) {
func Test_GetRouterReturnsRouter(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
resourceID := "testRouterId"
@@ -67,7 +67,7 @@ func Test_GetRouterReturnsRouter(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
@@ -79,7 +79,7 @@ func Test_GetRouterReturnsRouter(t *testing.T) {
func Test_GetRouterReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
networkID := "testNetworkId"
resourceID := "testRouterId"
@@ -88,7 +88,7 @@ func Test_GetRouterReturnsPermissionDenied(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
@@ -100,7 +100,7 @@ func Test_GetRouterReturnsPermissionDenied(t *testing.T) {
func Test_CreateRouterSuccessfully(t *testing.T) {
ctx := context.Background()
- userID := "allowedUser"
+ userID := "testAdminId"
router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999, true)
if err != nil {
require.NoError(t, err)
@@ -111,7 +111,7 @@ func Test_CreateRouterSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
@@ -126,7 +126,7 @@ func Test_CreateRouterSuccessfully(t *testing.T) {
func Test_CreateRouterFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
- userID := "invalidUser"
+ userID := "testUserId"
router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999, true)
if err != nil {
require.NoError(t, err)
@@ -137,7 +137,7 @@ func Test_CreateRouterFailsWithPermissionDenied(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
@@ -150,7 +150,7 @@ func Test_CreateRouterFailsWithPermissionDenied(t *testing.T) {
func Test_DeleteRouterSuccessfully(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "allowedUser"
+ userID := "testAdminId"
networkID := "testNetworkId"
routerID := "testRouterId"
@@ -159,7 +159,7 @@ func Test_DeleteRouterSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
@@ -170,7 +170,7 @@ func Test_DeleteRouterSuccessfully(t *testing.T) {
func Test_DeleteRouterFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
- userID := "invalidUser"
+ userID := "testUserId"
networkID := "testNetworkId"
routerID := "testRouterId"
@@ -179,7 +179,7 @@ func Test_DeleteRouterFailsWithPermissionDenied(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
@@ -190,7 +190,7 @@ func Test_DeleteRouterFailsWithPermissionDenied(t *testing.T) {
func Test_UpdateRouterSuccessfully(t *testing.T) {
ctx := context.Background()
- userID := "allowedUser"
+ userID := "testAdminId"
router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true)
if err != nil {
require.NoError(t, err)
@@ -201,7 +201,7 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
@@ -212,7 +212,7 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
- userID := "invalidUser"
+ userID := "testUserId"
router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true)
if err != nil {
require.NoError(t, err)
@@ -223,7 +223,7 @@ func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
- permissionsManager := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
diff --git a/management/server/peer.go b/management/server/peer.go
index e7d4b29f5..27825a148 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -17,6 +17,8 @@ import (
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/geolocation"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture"
@@ -37,17 +39,9 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, err
}
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() && settings.RegularUsersViewBlocked {
- return []*nbpeer.Peer{}, nil
+ return nil, status.NewPermissionValidationError(err)
}
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, nameFilter, ipFilter)
@@ -67,10 +61,23 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
peersMap[peer.ID] = peer
}
- if user.IsAdminOrServiceUser() {
+ if allowed {
return peers, nil
}
+ settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get account settings: %w", err)
+ }
+
+ if settings.RegularUsersViewBlocked {
+ return []*nbpeer.Peer{}, nil
+ }
+
+ return am.getUserAccessiblePeers(ctx, accountID, peersMap, peers)
+}
+
+func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, accountID string, peersMap map[string]*nbpeer.Peer, peers []*nbpeer.Peer) ([]*nbpeer.Peer, error) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, err
@@ -135,7 +142,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
- am.UpdateAccountPeers(ctx, accountID)
+ am.BufferUpdateAccountPeers(ctx, accountID)
}
return nil
@@ -183,13 +190,12 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
var peer *nbpeer.Peer
@@ -315,15 +321,12 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- if userID != activity.SystemInitiator {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
- if err != nil {
- return err
- }
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
- }
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
+ if err != nil {
+ return status.NewPermissionValidationError(err)
+ }
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
@@ -383,7 +386,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
if updateAccountPeers {
- am.UpdateAccountPeers(ctx, accountID)
+ am.BufferUpdateAccountPeers(ctx, accountID)
}
return nil
@@ -653,7 +656,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
unlock = nil
if updateAccountPeers {
- am.UpdateAccountPeers(ctx, accountID)
+ am.BufferUpdateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
@@ -748,7 +751,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) {
- am.UpdateAccountPeers(ctx, accountID)
+ am.BufferUpdateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
@@ -893,7 +896,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
unlockPeer = nil
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
- am.UpdateAccountPeers(ctx, accountID)
+ am.BufferUpdateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
@@ -1094,41 +1097,33 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
- if err != nil {
- return nil, err
- }
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
- if err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() && settings.RegularUsersViewBlocked {
- return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID)
- }
-
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
if err != nil {
return nil, err
}
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
+ if err != nil {
+ return nil, status.NewPermissionValidationError(err)
+ }
+ if allowed {
+ return peer, nil
+ }
+
+ user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ if err != nil {
+ return nil, err
+ }
+
// if admin or user owns this peer, return peer
if user.IsAdminOrServiceUser() || peer.UserID == userID {
return peer, nil
}
- // it is also possible that user doesn't own the peer but some of his peers have access to it,
- // this is a valid case, show the peer as well.
- userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
- if err != nil {
- return nil, err
- }
+ return am.checkIfUserOwnsPeer(ctx, accountID, userID, peer)
+}
+func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, err
@@ -1139,16 +1134,23 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return nil, err
}
+ // it is also possible that user doesn't own the peer but some of his peers have access to it,
+ // this is a valid case, show the peer as well.
+ userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
+ if err != nil {
+ return nil, err
+ }
+
for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap)
for _, aclPeer := range aclPeers {
- if aclPeer.ID == peerID {
+ if aclPeer.ID == peer.ID {
return peer, nil
}
}
}
- return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID)
+ return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peer.ID, accountID)
}
// UpdateAccountPeers updates all peers that belong to an account.
@@ -1226,6 +1228,21 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
}
}
+func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
+ mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
+ lock := mu.(*sync.Mutex)
+
+ if !lock.TryLock() {
+ return
+ }
+
+ go func() {
+ time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
+ lock.Unlock()
+ am.UpdateAccountPeers(ctx, accountID)
+ }()
+}
+
// UpdateAccountPeer updates a single peer that belongs to an account.
// Should be called when changes need to be synced to a specific peer only.
func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) {
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index b2563dcb0..406c3e49e 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -303,12 +303,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
group1.Peers = append(group1.Peers, peer1.ID)
group2.Peers = append(group2.Peers, peer2.ID)
- err = manager.SaveGroup(context.Background(), account.Id, userID, &group1)
+ err = manager.SaveGroup(context.Background(), account.Id, userID, &group1, true)
if err != nil {
t.Errorf("expecting group1 to be added, got failure %v", err)
return
}
- err = manager.SaveGroup(context.Background(), account.Id, userID, &group2)
+ err = manager.SaveGroup(context.Background(), account.Id, userID, &group2, true)
if err != nil {
t.Errorf("expecting group2 to be added, got failure %v", err)
return
@@ -327,7 +327,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
},
},
}
- policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
+ policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -375,7 +375,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
policy.Enabled = false
- _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -1264,9 +1264,9 @@ func Test_RegisterPeerByUser(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
+ permissionsManager := permissions.NewManager(s)
- permissionsManagerMock := permissions.NewManagerMock()
- am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
+ am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1333,9 +1333,9 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
+ permissionsManager := permissions.NewManager(s)
- permissionsManagerMock := permissions.NewManagerMock()
- am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
+ am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1406,8 +1406,9 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
- permissionsManagerMock := permissions.NewManagerMock()
- am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
+ permissionsManager := permissions.NewManager(s)
+
+ am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1477,7 +1478,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Name: "GroupC",
Peers: []string{},
},
- })
+ }, true)
require.NoError(t, err)
// create a user with auto groups
@@ -1653,7 +1654,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
},
},
- })
+ }, true)
require.NoError(t, err)
done := make(chan struct{})
diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go
index b00c1761b..fe48bf576 100644
--- a/management/server/peers/manager.go
+++ b/management/server/peers/manager.go
@@ -8,6 +8,8 @@ import (
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
)
@@ -31,7 +33,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager) Manag
}
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
- allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Peers, permissions.Read)
+ allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -44,13 +46,13 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
- allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Peers, permissions.Read)
+ allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
- return nil, status.NewPermissionDeniedError()
+ return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go
index 24ac09d1a..50a44eb0f 100644
--- a/management/server/permissions/manager.go
+++ b/management/server/permissions/manager.go
@@ -1,34 +1,24 @@
package permissions
+//go:generate go run github.com/golang/mock/mockgen -package permissions -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
+
import (
"context"
- "errors"
- "fmt"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
+ "github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
-type Module string
-
-const (
- Networks Module = "networks"
- Peers Module = "peers"
- Groups Module = "groups"
- Settings Module = "settings"
- Accounts Module = "accounts"
-)
-
-type Operation string
-
-const (
- Read Operation = "read"
- Write Operation = "write"
-)
-
type Manager interface {
- ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error)
+ ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error)
+ ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
}
@@ -36,16 +26,23 @@ type managerImpl struct {
store store.Store
}
-type managerMock struct {
-}
-
func NewManager(store store.Store) Manager {
return &managerImpl{
store: store,
}
}
-func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) {
+func (m *managerImpl) ValidateUserPermissions(
+ ctx context.Context,
+ accountID string,
+ userID string,
+ module modules.Module,
+ operation operations.Operation,
+) (bool, error) {
+ if userID == activity.SystemInitiator {
+ return true, nil
+ }
+
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return false, err
@@ -55,49 +52,42 @@ func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, us
return false, status.NewUserNotFoundError(userID)
}
+ if user.IsBlocked() {
+ return false, status.NewUserBlockedError()
+ }
+
if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return false, err
}
- switch module {
- case Accounts:
- if operation == Write && user.Role != types.UserRoleOwner {
- return false, nil
- }
- return true, nil
- default:
+ if operation == operations.Read && user.IsServiceUser {
+ return true, nil // this should be replaced by proper granular access role
}
- switch user.Role {
- case types.UserRoleAdmin, types.UserRoleOwner:
- return true, nil
- case types.UserRoleUser:
- return m.validateRegularUserPermissions(ctx, accountID, module, operation)
- case types.UserRoleBillingAdmin:
- return false, nil
- default:
- return false, errors.New("invalid role")
+ role, ok := roles.RolesMap[user.Role]
+ if !ok {
+ return false, status.NewUserRoleNotFoundError(string(user.Role))
}
+
+ return m.ValidateRoleModuleAccess(ctx, accountID, role, module, operation), nil
}
-func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID string, module Module, operation Operation) (bool, error) {
- settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
- if err != nil {
- return false, fmt.Errorf("failed to get settings: %w", err)
- }
- if settings.RegularUsersViewBlocked {
- return false, nil
+func (m *managerImpl) ValidateRoleModuleAccess(
+ ctx context.Context,
+ accountID string,
+ role roles.RolePermissions,
+ module modules.Module,
+ operation operations.Operation,
+) bool {
+ if permissions, ok := role.Permissions[module]; ok {
+ if allowed, exists := permissions[operation]; exists {
+ return allowed
+ }
+ log.WithContext(ctx).Tracef("operation %s not found on module %s for role %s", operation, module, role.Role)
+ return false
}
- if operation == Write {
- return false, nil
- }
-
- if module == Peers {
- return true, nil
- }
-
- return false, nil
+ return role.AutoAllowNew[operation]
}
func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
@@ -106,24 +96,3 @@ func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID strin
}
return nil
}
-
-func NewManagerMock() Manager {
- return &managerMock{}
-}
-
-func (m *managerMock) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) {
- switch userID {
- case "a23efe53-63fb-11ec-90d6-0242ac120003", "allowedUser", "testingUser", "account_creator", "serviceUserID", "test_user":
- return true, nil
- default:
- return false, nil
- }
-}
-
-func (m *managerMock) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
- // @note managers explicitly checked this, so should the mock
- if user.AccountID != accountID {
- return status.NewUserNotPartOfAccountError()
- }
- return nil
-}
diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go
new file mode 100644
index 000000000..266a24270
--- /dev/null
+++ b/management/server/permissions/manager_mock.go
@@ -0,0 +1,82 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: ./manager.go
+
+// Package permissions is a generated GoMock package.
+package permissions
+
+import (
+ context "context"
+ reflect "reflect"
+
+ gomock "github.com/golang/mock/gomock"
+ modules "github.com/netbirdio/netbird/management/server/permissions/modules"
+ operations "github.com/netbirdio/netbird/management/server/permissions/operations"
+ roles "github.com/netbirdio/netbird/management/server/permissions/roles"
+ types "github.com/netbirdio/netbird/management/server/types"
+)
+
+// MockManager is a mock of Manager interface.
+type MockManager struct {
+ ctrl *gomock.Controller
+ recorder *MockManagerMockRecorder
+}
+
+// MockManagerMockRecorder is the mock recorder for MockManager.
+type MockManagerMockRecorder struct {
+ mock *MockManager
+}
+
+// NewMockManager creates a new mock instance.
+func NewMockManager(ctrl *gomock.Controller) *MockManager {
+ mock := &MockManager{ctrl: ctrl}
+ mock.recorder = &MockManagerMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockManager) EXPECT() *MockManagerMockRecorder {
+ return m.recorder
+}
+
+// ValidateAccountAccess mocks base method.
+func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ValidateAccountAccess", ctx, accountID, user, allowOwnerAndAdmin)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// ValidateAccountAccess indicates an expected call of ValidateAccountAccess.
+func (mr *MockManagerMockRecorder) ValidateAccountAccess(ctx, accountID, user, allowOwnerAndAdmin interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateAccountAccess", reflect.TypeOf((*MockManager)(nil).ValidateAccountAccess), ctx, accountID, user, allowOwnerAndAdmin)
+}
+
+// ValidateRoleModuleAccess mocks base method.
+func (m *MockManager) ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ValidateRoleModuleAccess", ctx, accountID, role, module, operation)
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// ValidateRoleModuleAccess indicates an expected call of ValidateRoleModuleAccess.
+func (mr *MockManagerMockRecorder) ValidateRoleModuleAccess(ctx, accountID, role, module, operation interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateRoleModuleAccess", reflect.TypeOf((*MockManager)(nil).ValidateRoleModuleAccess), ctx, accountID, role, module, operation)
+}
+
+// ValidateUserPermissions mocks base method.
+func (m *MockManager) ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ValidateUserPermissions", ctx, accountID, userID, module, operation)
+ ret0, _ := ret[0].(bool)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// ValidateUserPermissions indicates an expected call of ValidateUserPermissions.
+func (mr *MockManagerMockRecorder) ValidateUserPermissions(ctx, accountID, userID, module, operation interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserPermissions", reflect.TypeOf((*MockManager)(nil).ValidateUserPermissions), ctx, accountID, userID, module, operation)
+}
diff --git a/management/server/permissions/modules/module.go b/management/server/permissions/modules/module.go
new file mode 100644
index 000000000..4c42b6190
--- /dev/null
+++ b/management/server/permissions/modules/module.go
@@ -0,0 +1,19 @@
+package modules
+
+type Module string
+
+const (
+ Networks Module = "networks"
+ Peers Module = "peers"
+ Groups Module = "groups"
+ Settings Module = "settings"
+ Accounts Module = "accounts"
+ Dns Module = "dns"
+ Nameservers Module = "nameservers"
+ Events Module = "events"
+ Policies Module = "policies"
+ Routes Module = "routes"
+ Users Module = "users"
+ SetupKeys Module = "setup_keys"
+ Pats Module = "pats"
+)
diff --git a/management/server/permissions/operations/operation.go b/management/server/permissions/operations/operation.go
new file mode 100644
index 000000000..11481234f
--- /dev/null
+++ b/management/server/permissions/operations/operation.go
@@ -0,0 +1,10 @@
+package operations
+
+type Operation string
+
+const (
+ Create Operation = "create"
+ Read Operation = "read"
+ Update Operation = "update"
+ Delete Operation = "delete"
+)
diff --git a/management/server/permissions/roles/admin.go b/management/server/permissions/roles/admin.go
new file mode 100644
index 000000000..af3a81297
--- /dev/null
+++ b/management/server/permissions/roles/admin.go
@@ -0,0 +1,25 @@
+package roles
+
+import (
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+var Admin = RolePermissions{
+ Role: types.UserRoleAdmin,
+ AutoAllowNew: map[operations.Operation]bool{
+ operations.Read: true,
+ operations.Create: true,
+ operations.Update: true,
+ operations.Delete: true,
+ },
+ Permissions: Permissions{
+ modules.Accounts: {
+ operations.Read: true,
+ operations.Create: false,
+ operations.Update: false,
+ operations.Delete: false,
+ },
+ },
+}
diff --git a/management/server/permissions/roles/owner.go b/management/server/permissions/roles/owner.go
new file mode 100644
index 000000000..668470e47
--- /dev/null
+++ b/management/server/permissions/roles/owner.go
@@ -0,0 +1,16 @@
+package roles
+
+import (
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+var Owner = RolePermissions{
+ Role: types.UserRoleOwner,
+ AutoAllowNew: map[operations.Operation]bool{
+ operations.Read: true,
+ operations.Create: true,
+ operations.Update: true,
+ operations.Delete: true,
+ },
+}
diff --git a/management/server/permissions/roles/role_permissions.go b/management/server/permissions/roles/role_permissions.go
new file mode 100644
index 000000000..dda7e6b99
--- /dev/null
+++ b/management/server/permissions/roles/role_permissions.go
@@ -0,0 +1,21 @@
+package roles
+
+import (
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+type RolePermissions struct {
+ Role types.UserRole
+ Permissions Permissions
+ AutoAllowNew map[operations.Operation]bool
+}
+
+type Permissions map[modules.Module]map[operations.Operation]bool
+
+var RolesMap = map[types.UserRole]RolePermissions{
+ types.UserRoleOwner: Owner,
+ types.UserRoleAdmin: Admin,
+ types.UserRoleUser: User,
+}
diff --git a/management/server/permissions/roles/user.go b/management/server/permissions/roles/user.go
new file mode 100644
index 000000000..bb3df0aea
--- /dev/null
+++ b/management/server/permissions/roles/user.go
@@ -0,0 +1,16 @@
+package roles
+
+import (
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+var User = RolePermissions{
+ Role: types.UserRoleUser,
+ AutoAllowNew: map[operations.Operation]bool{
+ operations.Read: false,
+ operations.Create: false,
+ operations.Update: false,
+ operations.Delete: false,
+ },
+}
diff --git a/management/server/policy.go b/management/server/policy.go
index 15111ba06..1e9331d43 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -7,6 +7,8 @@ import (
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/proto"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -17,38 +19,32 @@ import (
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID)
}
// SavePolicy in the store
-func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) {
+func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ operation := operations.Create
+ if !create {
+ operation = operations.Update
+ }
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operation)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
var isUpdate = policy.ID != ""
@@ -95,17 +91,12 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
- }
-
- if user.IsRegularUser() {
- return status.NewAdminPermissionError()
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
var policy *types.Policy
@@ -143,17 +134,12 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
// ListPolicies from the store.
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
diff --git a/management/server/policy_test.go b/management/server/policy_test.go
index 10b7fc2d1..0c1160cda 100644
--- a/management/server/policy_test.go
+++ b/management/server/policy_test.go
@@ -883,7 +883,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Name: "GroupD",
Peers: []string{peer1.ID, peer2.ID},
},
- })
+ }, true)
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -915,7 +915,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
},
},
- })
+ }, true)
assert.NoError(t, err)
select {
@@ -947,7 +947,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
},
},
- })
+ }, true)
assert.NoError(t, err)
select {
@@ -979,7 +979,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
},
},
- })
+ }, true)
assert.NoError(t, err)
select {
@@ -1010,7 +1010,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
},
},
- })
+ }, true)
assert.NoError(t, err)
select {
@@ -1030,7 +1030,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
}()
policyWithSourceAndDestinationPeers.Enabled = false
- policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
+ policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers, true)
assert.NoError(t, err)
select {
@@ -1051,7 +1051,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
policyWithSourceAndDestinationPeers.Description = "updated description"
policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"}
- policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
+ policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers, true)
assert.NoError(t, err)
select {
@@ -1071,7 +1071,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
}()
policyWithSourceAndDestinationPeers.Enabled = true
- policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
+ policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers, true)
assert.NoError(t, err)
select {
diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go
index 859ae6332..f91e89b45 100644
--- a/management/server/posture_checks.go
+++ b/management/server/posture_checks.go
@@ -10,6 +10,8 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
@@ -17,38 +19,32 @@ import (
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
}
// SavePostureChecks saves a posture check.
-func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
+func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ operation := operations.Create
+ if !create {
+ operation = operations.Update
+ }
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operation)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
var updateAccountPeers bool
@@ -94,17 +90,12 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
- }
-
- if !user.HasAdminPower() {
- return status.NewAdminPermissionError()
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
var postureChecks *posture.Checks
@@ -136,17 +127,12 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
// ListPostureChecks returns a list of posture checks.
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go
index bad162f05..232955f7d 100644
--- a/management/server/posture_checks_test.go
+++ b/management/server/posture_checks_test.go
@@ -33,7 +33,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
t.Run("Generic posture check flow", func(t *testing.T) {
// regular users can not create checks
- _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
+ _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}, true)
assert.Error(t, err)
// regular users cannot list check
@@ -48,7 +48,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
MinVersion: "0.26.0",
},
},
- })
+ }, true)
assert.NoError(t, err)
// admin users can list check
@@ -68,7 +68,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
},
},
},
- })
+ }, true)
assert.Error(t, err)
// admins can update posture checks
@@ -77,7 +77,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
MinVersion: "0.27.0",
},
}
- _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck)
+ _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck, true)
assert.NoError(t, err)
// users should not be able to delete posture checks
@@ -137,7 +137,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Name: "GroupC",
Peers: []string{},
},
- })
+ }, true)
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -156,7 +156,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
}
- postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA)
+ postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true)
require.NoError(t, err)
postureCheckB := &posture.Checks{
@@ -177,7 +177,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
+ postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true)
assert.NoError(t, err)
select {
@@ -200,7 +200,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0",
},
}
- _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true)
assert.NoError(t, err)
select {
@@ -232,7 +232,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
+ policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
assert.NoError(t, err)
select {
@@ -261,7 +261,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true)
assert.NoError(t, err)
select {
@@ -280,7 +280,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}()
policy.SourcePostureChecks = []string{}
- _, err := manager.SavePolicy(context.Background(), account.Id, userID, policy)
+ _, err := manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
assert.NoError(t, err)
select {
@@ -308,7 +308,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
})
- _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true)
assert.NoError(t, err)
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
@@ -325,7 +325,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
SourcePostureChecks: []string{postureCheckB.ID},
- })
+ }, true)
assert.NoError(t, err)
done := make(chan struct{})
@@ -339,7 +339,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0",
},
}
- _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true)
assert.NoError(t, err)
select {
@@ -369,7 +369,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
SourcePostureChecks: []string{postureCheckB.ID},
- })
+ }, true)
assert.NoError(t, err)
done := make(chan struct{})
@@ -383,7 +383,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0",
},
}
- _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true)
assert.NoError(t, err)
select {
@@ -408,7 +408,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
SourcePostureChecks: []string{postureCheckB.ID},
- })
+ }, true)
assert.NoError(t, err)
done := make(chan struct{})
@@ -426,7 +426,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
}
- _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB, true)
assert.NoError(t, err)
select {
@@ -465,7 +465,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
- postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA)
+ postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA, true)
require.NoError(t, err, "failed to save postureCheckA")
postureCheckB := &posture.Checks{
@@ -475,7 +475,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
- postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB)
+ postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB, true)
require.NoError(t, err, "failed to save postureCheckB")
policy := &types.Policy{
@@ -490,7 +490,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
SourcePostureChecks: []string{postureCheckA.ID},
}
- policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
+ policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to save policy")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
@@ -514,7 +514,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
policy.Rules[0].Sources = []string{"groupB"}
policy.Rules[0].Destinations = []string{"groupA"}
- _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
+ _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
@@ -525,7 +525,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
policy.Rules[0].Sources = []string{"groupA"}
policy.Rules[0].Destinations = []string{"groupB"}
- _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
+ _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
@@ -546,7 +546,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
policy.Rules[0].Sources = []string{"nonExistentGroup"}
policy.Rules[0].Destinations = []string{"nonExistentGroup"}
- _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
+ _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
diff --git a/management/server/route.go b/management/server/route.go
index abf20743a..8b91e127a 100644
--- a/management/server/route.go
+++ b/management/server/route.go
@@ -8,6 +8,8 @@ import (
"github.com/rs/xid"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -20,17 +22,12 @@ import (
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if !user.IsAdminOrServiceUser() {
- return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID)
@@ -123,13 +120,12 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Create)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
@@ -242,13 +238,12 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
@@ -318,13 +313,12 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
@@ -354,17 +348,12 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if !user.IsAdminOrServiceUser() {
- return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
diff --git a/management/server/route_test.go b/management/server/route_test.go
index 699c1304b..dcda3e6d1 100644
--- a/management/server/route_test.go
+++ b/management/server/route_test.go
@@ -1215,7 +1215,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
Name: "peer1 group",
Peers: []string{peer1ID},
}
- err = am.SaveGroup(context.Background(), account.Id, userID, newGroup)
+ err = am.SaveGroup(context.Background(), account.Id, userID, newGroup, true)
require.NoError(t, err)
rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser")
@@ -1227,7 +1227,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
- _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
+ _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, true)
require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
@@ -1260,7 +1260,6 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
- permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
@@ -1283,7 +1282,9 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes().
Return(&types.ExtraSettings{}, nil)
- return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
+ permissionsManager := permissions.NewManager(store)
+
+ return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
}
func createRouterStore(t *testing.T) (store.Store, error) {
@@ -1504,7 +1505,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
}
for _, group := range newGroup {
- err = am.SaveGroup(context.Background(), accountID, userID, group)
+ err = am.SaveGroup(context.Background(), accountID, userID, group, true)
if err != nil {
return nil, err
}
@@ -1958,7 +1959,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Name: "GroupC",
Peers: []string{},
},
- })
+ }, true)
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
@@ -2142,7 +2143,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1ID},
- })
+ }, true)
assert.NoError(t, err)
select {
@@ -2182,7 +2183,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1ID},
- })
+ }, true)
assert.NoError(t, err)
select {
diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go
index 2b3f4877b..94392ebf7 100644
--- a/management/server/settings/manager.go
+++ b/management/server/settings/manager.go
@@ -9,6 +9,8 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
"github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -44,7 +46,7 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
if userID != activity.SystemInitiator {
- ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Settings, permissions.Read)
+ ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
diff --git a/management/server/setupkey.go b/management/server/setupkey.go
index 8b73a7d1e..b0903c8d0 100644
--- a/management/server/setupkey.go
+++ b/management/server/setupkey.go
@@ -8,6 +8,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -56,17 +58,12 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Create)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
var setupKey *types.SetupKey
@@ -113,17 +110,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Update)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
var oldKey *types.SetupKey
@@ -175,17 +167,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
@@ -193,17 +180,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return nil, err
- }
-
- if user.IsRegularUser() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
@@ -221,17 +203,12 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
// DeleteSetupKey removes the setup key from the account
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Delete)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
- return err
- }
-
- if user.IsRegularUser() {
- return status.NewAdminPermissionError()
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
var deletedSetupKey *types.SetupKey
diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go
index 6e1e1cf7d..a561de40d 100644
--- a/management/server/setupkey_test.go
+++ b/management/server/setupkey_test.go
@@ -41,7 +41,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
Name: "group_name_2",
Peers: []string{},
},
- })
+ }, true)
if err != nil {
t.Fatal(err)
}
@@ -109,7 +109,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
- })
+ }, true)
if err != nil {
t.Fatal(err)
}
@@ -118,7 +118,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
- })
+ }, true)
if err != nil {
t.Fatal(err)
}
@@ -403,7 +403,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
- })
+ }, true)
assert.NoError(t, err)
policy := &types.Policy{
@@ -418,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
},
},
}
- _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
diff --git a/management/server/status/error.go b/management/server/status/error.go
index 5ab6f4e9e..8fbe0bad9 100644
--- a/management/server/status/error.go
+++ b/management/server/status/error.go
@@ -3,6 +3,8 @@ package status
import (
"errors"
"fmt"
+
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
)
const (
@@ -98,6 +100,11 @@ func NewUserNotFoundError(userKey string) error {
return Errorf(NotFound, "user: %s not found", userKey)
}
+// NewUserBlockedError creates a new Error with PermissionDenied type for a blocked user
+func NewUserBlockedError() error {
+ return Errorf(PermissionDenied, "user is blocked")
+}
+
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
func NewPeerNotRegisteredError() error {
return Errorf(Unauthenticated, "peer is not registered")
@@ -212,3 +219,11 @@ func NewPATNotFoundError(patID string) error {
func NewExtraSettingsNotFoundError() error {
return ErrExtraSettingsNotFound
}
+
+func NewUserRoleNotFoundError(role string) error {
+ return Errorf(NotFound, "user role: %s not found", role)
+}
+
+func NewOperationNotFoundError(operation operations.Operation) error {
+ return Errorf(NotFound, "operation: %s not found", operation)
+}
diff --git a/management/server/testdata/networks.sql b/management/server/testdata/networks.sql
index 8138ce520..bcb202084 100644
--- a/management/server/testdata/networks.sql
+++ b/management/server/testdata/networks.sql
@@ -16,3 +16,7 @@ INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId
CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`address` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO network_resources VALUES('testResourceId','testNetworkId','testAccountId','some-name','some-description','host','3.3.3.3/32');
INSERT INTO network_resources VALUES('anotherTestResourceId','testNetworkId','testAccountId','used-name','some-description','host','3.3.3.3/32');
+
+CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
+INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
diff --git a/management/server/types/config.go b/management/server/types/config.go
index d2e418264..7a16b20a1 100644
--- a/management/server/types/config.go
+++ b/management/server/types/config.go
@@ -154,6 +154,8 @@ type ProviderConfig struct {
UseIDToken bool
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
+ // DisablePromptLogin makes the PKCE flow to not prompt the user for login
+ DisablePromptLogin bool
}
// StoreConfig contains Store configuration
diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go
index de7dd57df..a85650136 100644
--- a/management/server/updatechannel.go
+++ b/management/server/updatechannel.go
@@ -42,10 +42,10 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
start := time.Now()
var found, dropped bool
- p.channelsMux.Lock()
+ p.channelsMux.RLock()
defer func() {
- p.channelsMux.Unlock()
+ p.channelsMux.RUnlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped)
}
@@ -141,12 +141,12 @@ func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
start := time.Now()
- p.channelsMux.Lock()
+ p.channelsMux.RLock()
m := make(map[string]struct{})
defer func() {
- p.channelsMux.Unlock()
+ p.channelsMux.RUnlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountGetAllConnectedPeersDuration(time.Since(start), len(m))
}
@@ -163,10 +163,10 @@ func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
func (p *PeersUpdateManager) HasChannel(peerID string) bool {
start := time.Now()
- p.channelsMux.Lock()
+ p.channelsMux.RLock()
defer func() {
- p.channelsMux.Unlock()
+ p.channelsMux.RUnlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountHasChannelDuration(time.Since(start))
}
diff --git a/management/server/user.go b/management/server/user.go
index c446bd8ea..9ec16e72c 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -14,6 +14,8 @@ import (
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -25,17 +27,12 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
- return nil, err
- }
-
- if !initiatorUser.HasAdminPower() {
- return nil, status.NewAdminPermissionError()
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
if role == types.UserRoleOwner {
@@ -88,12 +85,16 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return nil, err
}
- initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Users, operations.Create)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
+ }
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
+ initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ if err != nil {
return nil, err
}
@@ -237,12 +238,12 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return err
}
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
- return err
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete)
+ if err != nil {
+ return status.NewPermissionValidationError(err)
}
-
- if !initiatorUser.HasAdminPower() {
- return status.NewAdminPermissionError()
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
@@ -294,13 +295,12 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
}
- initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create)
if err != nil {
- return err
+ return status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
- return err
+ if !allowed {
+ return status.NewPermissionDeniedError()
}
// check if the user is already registered with this ID
@@ -342,12 +342,16 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
}
- initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Create)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
+ }
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
+ initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
+ if err != nil {
return nil, err
}
@@ -380,25 +384,29 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Delete)
+ if err != nil {
+ return status.NewPermissionValidationError(err)
+ }
+ if !allowed {
+ return status.NewPermissionDeniedError()
+ }
+
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
return err
}
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
- return err
- }
-
- if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() {
- return status.NewAdminPermissionError()
- }
-
- pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID)
+ targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
if err != nil {
return err
}
- targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
+ if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
+ return status.NewAdminPermissionError()
+ }
+
+ pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID)
if err != nil {
return err
}
@@ -415,16 +423,25 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
// GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) {
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Read)
+ if err != nil {
+ return nil, status.NewPermissionValidationError(err)
+ }
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
+ }
+
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, err
}
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
+ targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
+ if err != nil {
return nil, err
}
- if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() {
+ if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return nil, status.NewAdminPermissionError()
}
@@ -433,16 +450,25 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
// GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) {
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Read)
+ if err != nil {
+ return nil, status.NewPermissionValidationError(err)
+ }
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
+ }
+
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, err
}
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
+ targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
+ if err != nil {
return nil, err
}
- if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() {
+ if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return nil, status.NewAdminPermissionError()
}
@@ -480,19 +506,13 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
return nil, nil //nolint:nilnil
}
- initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) // TODO: split by Create and Update
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
-
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
- return nil, err
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
}
-
- if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() {
- return nil, status.NewAdminPermissionError()
- }
-
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
@@ -513,6 +533,11 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
groupsMap[group.ID] = group
}
+ initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
+ if err != nil {
+ return nil, err
+ }
+
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, update := range updates {
if update == nil {
@@ -795,33 +820,37 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
// based on provided user role.
func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, initiatorUserID string) (map[string]*types.UserInfo, error) {
- accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Read)
if err != nil {
- return nil, err
+ return nil, status.NewPermissionValidationError(err)
}
- initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
+ user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to get user: %w", err)
}
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
- return nil, err
+ accountUsers := []*types.User{}
+ switch {
+ case allowed:
+ accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
+ if err != nil {
+ return nil, err
+ }
+ case user.AccountID == accountID:
+ accountUsers = append(accountUsers, user)
+ default:
+ return map[string]*types.UserInfo{}, nil
}
return am.BuildUserInfosForAccount(ctx, accountID, initiatorUserID, accountUsers)
}
// BuildUserInfosForAccount builds user info for the given account.
-func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
+func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, _ string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
var queriedUsers []*idp.UserData
var err error
- initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
- if err != nil {
- return nil, err
- }
-
if !isNil(am.idpManager) {
users := make(map[string]userLoggedInOnce, len(accountUsers))
usersFromIntegration := make([]*idp.UserData, 0)
@@ -860,11 +889,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 {
for _, accountUser := range accountUsers {
- if initiatorUser.IsRegularUser() && initiatorUser.Id != accountUser.Id {
- // if user is not an admin then show only current user and do not show other users
- continue
- }
-
info, err := accountUser.ToUserInfo(nil, settings)
if err != nil {
return nil, err
@@ -876,11 +900,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
}
for _, localUser := range accountUsers {
- if initiatorUser.IsRegularUser() && initiatorUser.Id != localUser.Id {
- // if user is not an admin then show only current user and do not show other users
- continue
- }
-
var info *types.UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
info, err = localUser.ToUserInfo(queriedUser, settings)
@@ -977,19 +996,19 @@ func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUs
// If an error occurs while deleting the user, the function skips it and continues deleting other users.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error {
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete)
+ if err != nil {
+ return status.NewPermissionValidationError(err)
+ }
+ if !allowed {
+ return status.NewPermissionDeniedError()
+ }
+
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
return err
}
- if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
- return err
- }
-
- if !initiatorUser.HasAdminPower() {
- return status.NewAdminPermissionError()
- }
-
var allErrors error
var updateAccountPeers bool
@@ -1213,3 +1232,30 @@ func validateUserInvite(invite *types.UserInfo) error {
return nil
}
+
+// GetCurrentUserInfo retrieves the account's current user info
+func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) {
+ user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ if user.IsBlocked() {
+ return nil, status.NewUserBlockedError()
+ }
+
+ if user.IsServiceUser {
+ return nil, status.NewPermissionDeniedError()
+ }
+
+ if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
+ return nil, err
+ }
+
+ userInfo, err := am.getUserInfo(ctx, user, accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ return userInfo, nil
+}
diff --git a/management/server/user_test.go b/management/server/user_test.go
index d3344738b..83c5ac49a 100644
--- a/management/server/user_test.go
+++ b/management/server/user_test.go
@@ -13,6 +13,7 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/util"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -59,11 +60,11 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := DefaultAccountManager{
Store: s,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
@@ -109,11 +110,11 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
@@ -137,11 +138,11 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
@@ -166,11 +167,11 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn)
@@ -191,11 +192,11 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
@@ -218,17 +219,18 @@ func TestUser_DeletePAT(t *testing.T) {
HashedToken: mockToken1,
},
},
+ Role: types.UserRoleAdmin,
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
err = am.DeletePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1)
@@ -261,17 +263,18 @@ func TestUser_GetPAT(t *testing.T) {
HashedToken: mockToken1,
},
},
+ Role: types.UserRoleAdmin,
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
pat, err := am.GetPAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1)
@@ -304,17 +307,18 @@ func TestUser_GetAllPATs(t *testing.T) {
HashedToken: mockToken2,
},
},
+ Role: types.UserRoleAdmin,
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
pats, err := am.GetAllPATs(context.Background(), mockAccountID, mockUserID, mockUserID)
@@ -406,11 +410,11 @@ func TestUser_CreateServiceUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
user, err := am.createServiceUser(context.Background(), mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"})
@@ -453,11 +457,11 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
@@ -501,11 +505,11 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
_, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
@@ -532,12 +536,12 @@ func TestUser_InviteNewUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
cacheLoading: map[string]chan struct{}{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
@@ -640,11 +644,11 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID)
@@ -678,11 +682,11 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockUserID)
@@ -732,12 +736,11 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
- Store: store,
- eventStore: &activity.InMemoryEventStore{},
- integratedPeerValidator: MocIntegratedValidator{},
- permissionsManager: permissionsMananagerMock,
+ Store: store,
+ eventStore: &activity.InMemoryEventStore{},
+ permissionsManager: permissionsManager,
}
testCases := []struct {
@@ -842,12 +845,12 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
testCases := []struct {
@@ -953,11 +956,11 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
claims := nbcontext.UserAuth{
@@ -991,11 +994,11 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
users, err := am.ListUsers(context.Background(), mockAccountID)
@@ -1080,11 +1083,11 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
users, err := am.ListUsers(context.Background(), mockAccountID)
@@ -1125,13 +1128,13 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
cacheLoading: map[string]chan struct{}{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
@@ -1188,11 +1191,11 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
@@ -1222,11 +1225,11 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
- permissionsMananagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsMananagerMock,
+ permissionsManager: permissionsManager,
}
users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockServiceUserID)
@@ -1417,7 +1420,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
- })
+ }, true)
require.NoError(t, err)
policy := &types.Policy{
@@ -1432,7 +1435,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
},
},
}
- _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1589,13 +1592,11 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "")
require.NoError(t, s.SaveAccount(context.Background(), account2))
- permissionsManagerMock := permissions.NewManagerMock()
+ permissionsManager := permissions.NewManager(s)
am := DefaultAccountManager{
Store: s,
eventStore: &activity.InMemoryEventStore{},
- idpManager: nil,
- cacheLoading: map[string]chan struct{}{},
- permissionsManager: permissionsManagerMock,
+ permissionsManager: permissionsManager,
}
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
@@ -1607,3 +1608,175 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)
assert.Equal(t, account1.Users[targetId].AutoGroups, user.AutoGroups)
}
+
+func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
+ store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ if err != nil {
+ t.Fatalf("Error when creating store: %s", err)
+ }
+ t.Cleanup(cleanup)
+
+ account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "")
+ account1.Settings.RegularUsersViewBlocked = false
+ account1.Users["blocked-user"] = &types.User{
+ Id: "blocked-user",
+ AccountID: account1.Id,
+ Blocked: true,
+ }
+ account1.Users["service-user"] = &types.User{
+ Id: "service-user",
+ IsServiceUser: true,
+ ServiceUserName: "service-user",
+ }
+ account1.Users["regular-user"] = &types.User{
+ Id: "regular-user",
+ Role: types.UserRoleUser,
+ }
+ account1.Users["admin-user"] = &types.User{
+ Id: "admin-user",
+ Role: types.UserRoleAdmin,
+ }
+ require.NoError(t, store.SaveAccount(context.Background(), account1))
+
+ account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "")
+ account2.Users["settings-blocked-user"] = &types.User{
+ Id: "settings-blocked-user",
+ Role: types.UserRoleUser,
+ }
+ require.NoError(t, store.SaveAccount(context.Background(), account2))
+
+ permissionsManager := permissions.NewManager(store)
+ am := DefaultAccountManager{
+ Store: store,
+ eventStore: &activity.InMemoryEventStore{},
+ permissionsManager: permissionsManager,
+ }
+
+ tt := []struct {
+ name string
+ accountId string
+ userId string
+ expectedErr error
+ expectedResult *types.UserInfo
+ }{
+ {
+ name: "not found",
+ accountId: account1.Id,
+ userId: "not-found",
+ expectedErr: status.NewUserNotFoundError("not-found"),
+ },
+ {
+ name: "not part of account",
+ accountId: account1.Id,
+ userId: "account2Owner",
+ expectedErr: status.NewUserNotPartOfAccountError(),
+ },
+ {
+ name: "blocked",
+ accountId: account1.Id,
+ userId: "blocked-user",
+ expectedErr: status.NewUserBlockedError(),
+ },
+ {
+ name: "service user",
+ accountId: account1.Id,
+ userId: "service-user",
+ expectedErr: status.NewPermissionDeniedError(),
+ },
+ {
+ name: "owner user",
+ accountId: account1.Id,
+ userId: "account1Owner",
+ expectedResult: &types.UserInfo{
+ ID: "account1Owner",
+ Name: "",
+ Role: "owner",
+ AutoGroups: []string{},
+ Status: "active",
+ IsServiceUser: false,
+ IsBlocked: false,
+ NonDeletable: false,
+ LastLogin: time.Time{},
+ Issued: "api",
+ IntegrationReference: integration_reference.IntegrationReference{},
+ Permissions: types.UserPermissions{
+ DashboardView: "full",
+ },
+ },
+ },
+ {
+ name: "regular user",
+ accountId: account1.Id,
+ userId: "regular-user",
+ expectedResult: &types.UserInfo{
+ ID: "regular-user",
+ Name: "",
+ Role: "user",
+ Status: "active",
+ IsServiceUser: false,
+ IsBlocked: false,
+ NonDeletable: false,
+ LastLogin: time.Time{},
+ Issued: "api",
+ IntegrationReference: integration_reference.IntegrationReference{},
+ Permissions: types.UserPermissions{
+ DashboardView: "limited",
+ },
+ },
+ },
+ {
+ name: "admin user",
+ accountId: account1.Id,
+ userId: "admin-user",
+ expectedResult: &types.UserInfo{
+ ID: "admin-user",
+ Name: "",
+ Role: "admin",
+ Status: "active",
+ IsServiceUser: false,
+ IsBlocked: false,
+ NonDeletable: false,
+ LastLogin: time.Time{},
+ Issued: "api",
+ IntegrationReference: integration_reference.IntegrationReference{},
+ Permissions: types.UserPermissions{
+ DashboardView: "full",
+ },
+ },
+ },
+ {
+ name: "settings blocked regular user",
+ accountId: account2.Id,
+ userId: "settings-blocked-user",
+ expectedResult: &types.UserInfo{
+ ID: "settings-blocked-user",
+ Name: "",
+ Role: "user",
+ Status: "active",
+ IsServiceUser: false,
+ IsBlocked: false,
+ NonDeletable: false,
+ LastLogin: time.Time{},
+ Issued: "api",
+ IntegrationReference: integration_reference.IntegrationReference{},
+ Permissions: types.UserPermissions{
+ DashboardView: "blocked",
+ },
+ },
+ },
+ }
+
+ for _, tc := range tt {
+ t.Run(tc.name, func(t *testing.T) {
+ result, err := am.GetCurrentUserInfo(context.Background(), tc.accountId, tc.userId)
+
+ if tc.expectedErr != nil {
+ assert.Equal(t, err, tc.expectedErr)
+ return
+ }
+
+ require.NoError(t, err)
+ assert.EqualValues(t, tc.expectedResult, result)
+ })
+ }
+}
diff --git a/proto/management/management.proto b/proto/management/management.proto
index 8770c613c..4bb66b50b 100644
--- a/proto/management/management.proto
+++ b/proto/management/management.proto
@@ -368,6 +368,8 @@ message ProviderConfig {
string AuthorizationEndpoint = 9;
// RedirectURLs handles authorization code from IDP manager
repeated string RedirectURLs = 10;
+ // DisablePromptLogin makes the PKCE flow to not prompt the user for login
+ bool DisablePromptLogin = 11;
}
// Route represents a route.Route object
diff --git a/release_files/install.sh b/release_files/install.sh
index 459645c58..e5a61dcfe 100755
--- a/release_files/install.sh
+++ b/release_files/install.sh
@@ -109,6 +109,9 @@ add_apt_repo() {
curl -sSL https://pkgs.netbird.io/debian/public.key \
| ${SUDO} gpg --dearmor -o /usr/share/keyrings/netbird-archive-keyring.gpg
+ # Explicitly set the file permission
+ ${SUDO} chmod 0644 /usr/share/keyrings/netbird-archive-keyring.gpg
+
echo 'deb [signed-by=/usr/share/keyrings/netbird-archive-keyring.gpg] https://pkgs.netbird.io/debian stable main' \
| ${SUDO} tee /etc/apt/sources.list.d/netbird.list