+
+
+
+
+
- {{ .Error }}.
-
- {{ else }}
-
-
-
- Login successful
+
+
+
+
+
+ {{ if .Error }}
+
+
+ {{ else }}
+
+
+ {{ end }}
+
+
+
+ {{ if .Error }}
+
Login Failed
+ {{ else }}
+ Login Successful
+ {{ end }}
+
+
+
+ {{ if .Error }}
+
+ {{ .Error }}
+
+ {{ else }}
+
+
+ Your device is now registered and logged in to NetBird. You can now close this window.
+
+ {{ end }}
+
+
- Your device is now registered and logged in to NetBird.
-
- You can now close this window.
- {{ end }}
+
diff --git a/client/internal/templates/pkce_auth_msg_test.go b/client/internal/templates/pkce_auth_msg_test.go
new file mode 100644
index 000000000..75b1c9e76
--- /dev/null
+++ b/client/internal/templates/pkce_auth_msg_test.go
@@ -0,0 +1,299 @@
+package templates
+
+import (
+ "html/template"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestPKCEAuthMsgTemplate(t *testing.T) {
+ tests := []struct {
+ name string
+ data map[string]string
+ outputFile string
+ expectedTitle string
+ expectedInContent []string
+ notExpectedInContent []string
+ }{
+ {
+ name: "error_state",
+ data: map[string]string{
+ "Error": "authentication failed: invalid state",
+ },
+ outputFile: "pkce-auth-error.html",
+ expectedTitle: "Login Failed",
+ expectedInContent: []string{
+ "authentication failed: invalid state",
+ "Login Failed",
+ },
+ notExpectedInContent: []string{
+ "Login Successful",
+ "Your device is now registered and logged in to NetBird",
+ },
+ },
+ {
+ name: "success_state",
+ data: map[string]string{
+ // No error field means success
+ },
+ outputFile: "pkce-auth-success.html",
+ expectedTitle: "Login Successful",
+ expectedInContent: []string{
+ "Login Successful",
+ "Your device is now registered and logged in to NetBird. You can now close this window.",
+ },
+ notExpectedInContent: []string{
+ "Login Failed",
+ },
+ },
+ {
+ name: "error_state_timeout",
+ data: map[string]string{
+ "Error": "authentication timeout: request expired after 5 minutes",
+ },
+ outputFile: "pkce-auth-timeout.html",
+ expectedTitle: "Login Failed",
+ expectedInContent: []string{
+ "authentication timeout: request expired after 5 minutes",
+ "Login Failed",
+ },
+ notExpectedInContent: []string{
+ "Login Successful",
+ "Your device is now registered and logged in to NetBird",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Parse the template
+ tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
+ if err != nil {
+ t.Fatalf("Failed to parse template: %v", err)
+ }
+
+ // Create temp directory for this test
+ tempDir := t.TempDir()
+ outputPath := filepath.Join(tempDir, tt.outputFile)
+
+ // Create output file
+ file, err := os.Create(outputPath)
+ if err != nil {
+ t.Fatalf("Failed to create output file: %v", err)
+ }
+
+ // Execute the template
+ if err := tmpl.Execute(file, tt.data); err != nil {
+ file.Close()
+ t.Fatalf("Failed to execute template: %v", err)
+ }
+ file.Close()
+
+ t.Logf("Generated test output: %s", outputPath)
+
+ // Read the generated file
+ content, err := os.ReadFile(outputPath)
+ if err != nil {
+ t.Fatalf("Failed to read output file: %v", err)
+ }
+
+ contentStr := string(content)
+
+ // Verify file has content
+ if len(contentStr) == 0 {
+ t.Error("Output file is empty")
+ }
+
+ // Verify basic HTML structure
+ basicElements := []string{
+ "",
+ "",
+ "",
+ "NetBird",
+ }
+
+ for _, elem := range basicElements {
+ if !contains(contentStr, elem) {
+ t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
+ }
+ }
+
+ // Verify expected title
+ if !contains(contentStr, tt.expectedTitle) {
+ t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle)
+ }
+
+ // Verify expected content is present
+ for _, expected := range tt.expectedInContent {
+ if !contains(contentStr, expected) {
+ t.Errorf("Expected HTML to contain '%s', but it was not found", expected)
+ }
+ }
+
+ // Verify unexpected content is not present
+ for _, notExpected := range tt.notExpectedInContent {
+ if contains(contentStr, notExpected) {
+ t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected)
+ }
+ }
+ })
+ }
+}
+
+func TestPKCEAuthMsgTemplateValidation(t *testing.T) {
+ // Test that the template can be parsed without errors
+ tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
+ if err != nil {
+ t.Fatalf("Template parsing failed: %v", err)
+ }
+
+ // Test with empty data
+ t.Run("empty_data", func(t *testing.T) {
+ tempDir := t.TempDir()
+ outputPath := filepath.Join(tempDir, "empty-data.html")
+
+ file, err := os.Create(outputPath)
+ if err != nil {
+ t.Fatalf("Failed to create output file: %v", err)
+ }
+ defer file.Close()
+
+ if err := tmpl.Execute(file, nil); err != nil {
+ t.Errorf("Template execution with nil data failed: %v", err)
+ }
+ })
+
+ // Test with error data
+ t.Run("with_error", func(t *testing.T) {
+ tempDir := t.TempDir()
+ outputPath := filepath.Join(tempDir, "with-error.html")
+
+ file, err := os.Create(outputPath)
+ if err != nil {
+ t.Fatalf("Failed to create output file: %v", err)
+ }
+ defer file.Close()
+
+ data := map[string]string{
+ "Error": "test error message",
+ }
+ if err := tmpl.Execute(file, data); err != nil {
+ t.Errorf("Template execution with error data failed: %v", err)
+ }
+ })
+}
+
+func TestPKCEAuthMsgTemplateContent(t *testing.T) {
+ // Test that the template contains expected elements
+ tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
+ if err != nil {
+ t.Fatalf("Template parsing failed: %v", err)
+ }
+
+ t.Run("success_content", func(t *testing.T) {
+ tempDir := t.TempDir()
+ outputPath := filepath.Join(tempDir, "success.html")
+
+ file, err := os.Create(outputPath)
+ if err != nil {
+ t.Fatalf("Failed to create output file: %v", err)
+ }
+ defer file.Close()
+
+ data := map[string]string{}
+ if err := tmpl.Execute(file, data); err != nil {
+ t.Fatalf("Template execution failed: %v", err)
+ }
+
+ // Read the file and verify it contains expected content
+ content, err := os.ReadFile(outputPath)
+ if err != nil {
+ t.Fatalf("Failed to read output file: %v", err)
+ }
+
+ // Check for success indicators
+ contentStr := string(content)
+ if len(contentStr) == 0 {
+ t.Error("Generated HTML is empty")
+ }
+
+ // Basic HTML structure checks
+ requiredElements := []string{
+ "",
+ "",
+ "",
+ "Login Successful",
+ "NetBird",
+ }
+
+ for _, elem := range requiredElements {
+ if !contains(contentStr, elem) {
+ t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
+ }
+ }
+ })
+
+ t.Run("error_content", func(t *testing.T) {
+ tempDir := t.TempDir()
+ outputPath := filepath.Join(tempDir, "error.html")
+
+ file, err := os.Create(outputPath)
+ if err != nil {
+ t.Fatalf("Failed to create output file: %v", err)
+ }
+ defer file.Close()
+
+ errorMsg := "test error message"
+ data := map[string]string{
+ "Error": errorMsg,
+ }
+ if err := tmpl.Execute(file, data); err != nil {
+ t.Fatalf("Template execution failed: %v", err)
+ }
+
+ // Read the file and verify it contains expected content
+ content, err := os.ReadFile(outputPath)
+ if err != nil {
+ t.Fatalf("Failed to read output file: %v", err)
+ }
+
+ // Check for error indicators
+ contentStr := string(content)
+ if len(contentStr) == 0 {
+ t.Error("Generated HTML is empty")
+ }
+
+ // Basic HTML structure checks
+ requiredElements := []string{
+ "",
+ "",
+ "",
+ "Login Failed",
+ errorMsg,
+ }
+
+ for _, elem := range requiredElements {
+ if !contains(contentStr, elem) {
+ t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
+ }
+ }
+ })
+}
+
+func contains(s, substr string) bool {
+ return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
+ (len(s) > 0 && len(substr) > 0 && containsHelper(s, substr)))
+}
+
+func containsHelper(s, substr string) bool {
+ for i := 0; i <= len(s)-len(substr); i++ {
+ if s[i:i+len(substr)] == substr {
+ return true
+ }
+ }
+ return false
+}
diff --git a/client/internal/updatemanager/doc.go b/client/internal/updatemanager/doc.go
new file mode 100644
index 000000000..54d1bdeab
--- /dev/null
+++ b/client/internal/updatemanager/doc.go
@@ -0,0 +1,35 @@
+// Package updatemanager provides automatic update management for the NetBird client.
+// It monitors for new versions, handles update triggers from management server directives,
+// and orchestrates the download and installation of client updates.
+//
+// # Overview
+//
+// The update manager operates as a background service that continuously monitors for
+// available updates and automatically initiates the update process when conditions are met.
+// It integrates with the installer package to perform the actual installation.
+//
+// # Update Flow
+//
+// The complete update process follows these steps:
+//
+// 1. Manager receives update directive via SetVersion() or detects new version
+// 2. Manager validates update should proceed (version comparison, rate limiting)
+// 3. Manager publishes "updating" event to status recorder
+// 4. Manager persists UpdateState to track update attempt
+// 5. Manager downloads installer file (.msi or .exe) to temporary directory
+// 6. Manager triggers installation via installer.RunInstallation()
+// 7. Installer package handles the actual installation process
+// 8. On next startup, CheckUpdateSuccess() verifies update completion
+// 9. Manager publishes success/failure event to status recorder
+// 10. Manager cleans up UpdateState
+//
+// # State Management
+//
+// Update state is persisted across restarts to track update attempts:
+//
+// - PreUpdateVersion: Version before update attempt
+// - TargetVersion: Version attempting to update to
+//
+// This enables verification of successful updates and appropriate user notification
+// after the client restarts with the new version.
+package updatemanager
diff --git a/client/internal/updatemanager/downloader/downloader.go b/client/internal/updatemanager/downloader/downloader.go
new file mode 100644
index 000000000..2ac36efed
--- /dev/null
+++ b/client/internal/updatemanager/downloader/downloader.go
@@ -0,0 +1,138 @@
+package downloader
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/version"
+)
+
+const (
+ userAgent = "NetBird agent installer/%s"
+ DefaultRetryDelay = 3 * time.Second
+)
+
+func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error {
+ log.Debugf("starting download from %s", url)
+
+ out, err := os.Create(dstFile)
+ if err != nil {
+ return fmt.Errorf("failed to create destination file %q: %w", dstFile, err)
+ }
+ defer func() {
+ if cerr := out.Close(); cerr != nil {
+ log.Warnf("error closing file %q: %v", dstFile, cerr)
+ }
+ }()
+
+ // First attempt
+ err = downloadToFileOnce(ctx, url, out)
+ if err == nil {
+ log.Infof("successfully downloaded file to %s", dstFile)
+ return nil
+ }
+
+ // If retryDelay is 0, don't retry
+ if retryDelay == 0 {
+ return err
+ }
+
+ log.Warnf("download failed, retrying after %v: %v", retryDelay, err)
+
+ // Sleep before retry
+ if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil {
+ return fmt.Errorf("download cancelled during retry delay: %w", sleepErr)
+ }
+
+ // Truncate file before retry
+ if err := out.Truncate(0); err != nil {
+ return fmt.Errorf("failed to truncate file on retry: %w", err)
+ }
+ if _, err := out.Seek(0, 0); err != nil {
+ return fmt.Errorf("failed to seek to beginning of file: %w", err)
+ }
+
+ // Second attempt
+ if err := downloadToFileOnce(ctx, url, out); err != nil {
+ return fmt.Errorf("download failed after retry: %w", err)
+ }
+
+ log.Infof("successfully downloaded file to %s", dstFile)
+ return nil
+}
+
+func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create HTTP request: %w", err)
+ }
+
+ // Add User-Agent header
+ req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to perform HTTP request: %w", err)
+ }
+ defer func() {
+ if cerr := resp.Body.Close(); cerr != nil {
+ log.Warnf("error closing response body: %v", cerr)
+ }
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
+ }
+
+ data, err := io.ReadAll(io.LimitReader(resp.Body, limit))
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ return data, nil
+}
+
+func downloadToFileOnce(ctx context.Context, url string, out *os.File) error {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return fmt.Errorf("failed to create HTTP request: %w", err)
+ }
+
+ // Add User-Agent header
+ req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to perform HTTP request: %w", err)
+ }
+ defer func() {
+ if cerr := resp.Body.Close(); cerr != nil {
+ log.Warnf("error closing response body: %v", cerr)
+ }
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
+ }
+
+ if _, err := io.Copy(out, resp.Body); err != nil {
+ return fmt.Errorf("failed to write response body to file: %w", err)
+ }
+
+ return nil
+}
+
+func sleepWithContext(ctx context.Context, duration time.Duration) error {
+ select {
+ case <-time.After(duration):
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
diff --git a/client/internal/updatemanager/downloader/downloader_test.go b/client/internal/updatemanager/downloader/downloader_test.go
new file mode 100644
index 000000000..045db3a2d
--- /dev/null
+++ b/client/internal/updatemanager/downloader/downloader_test.go
@@ -0,0 +1,199 @@
+package downloader
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+const (
+ retryDelay = 100 * time.Millisecond
+)
+
+func TestDownloadToFile_Success(t *testing.T) {
+ // Create a test server that responds successfully
+ content := "test file content"
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(content))
+ }))
+ defer server.Close()
+
+ // Create a temporary file for download
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ // Download the file
+ err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile)
+ if err != nil {
+ t.Fatalf("expected no error, got: %v", err)
+ }
+
+ // Verify the file content
+ data, err := os.ReadFile(dstFile)
+ if err != nil {
+ t.Fatalf("failed to read downloaded file: %v", err)
+ }
+
+ if string(data) != content {
+ t.Errorf("expected content %q, got %q", content, string(data))
+ }
+}
+
+func TestDownloadToFile_SuccessAfterRetry(t *testing.T) {
+ content := "test file content after retry"
+ var attemptCount atomic.Int32
+
+ // Create a test server that fails on first attempt, succeeds on second
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ attempt := attemptCount.Add(1)
+ if attempt == 1 {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("error"))
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(content))
+ }))
+ defer server.Close()
+
+ // Create a temporary file for download
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ // Download the file (should succeed after retry)
+ if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil {
+ t.Fatalf("expected no error after retry, got: %v", err)
+ }
+
+ // Verify the file content
+ data, err := os.ReadFile(dstFile)
+ if err != nil {
+ t.Fatalf("failed to read downloaded file: %v", err)
+ }
+
+ if string(data) != content {
+ t.Errorf("expected content %q, got %q", content, string(data))
+ }
+
+ // Verify it took 2 attempts
+ if attemptCount.Load() != 2 {
+ t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
+ }
+}
+
+func TestDownloadToFile_FailsAfterRetry(t *testing.T) {
+ var attemptCount atomic.Int32
+
+ // Create a test server that always fails
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ attemptCount.Add(1)
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("error"))
+ }))
+ defer server.Close()
+
+ // Create a temporary file for download
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ // Download the file (should fail after retry)
+ if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil {
+ t.Fatal("expected error after retry, got nil")
+ }
+
+ // Verify it tried 2 times
+ if attemptCount.Load() != 2 {
+ t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
+ }
+}
+
+func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) {
+ var attemptCount atomic.Int32
+
+ // Create a test server that always fails
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ attemptCount.Add(1)
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer server.Close()
+
+ // Create a temporary file for download
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ // Create a context that will be cancelled during retry delay
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Cancel after a short delay (during the retry sleep)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ cancel()
+ }()
+
+ // Download the file (should fail due to context cancellation during retry)
+ err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile)
+ if err == nil {
+ t.Fatal("expected error due to context cancellation, got nil")
+ }
+
+ // Should have only made 1 attempt (cancelled during retry delay)
+ if attemptCount.Load() != 1 {
+ t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
+ }
+}
+
+func TestDownloadToFile_InvalidURL(t *testing.T) {
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile)
+ if err == nil {
+ t.Fatal("expected error for invalid URL, got nil")
+ }
+}
+
+func TestDownloadToFile_InvalidDestination(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("test"))
+ }))
+ defer server.Close()
+
+ // Use an invalid destination path
+ err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt")
+ if err == nil {
+ t.Fatal("expected error for invalid destination, got nil")
+ }
+}
+
+func TestDownloadToFile_NoRetry(t *testing.T) {
+ var attemptCount atomic.Int32
+
+ // Create a test server that always fails
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ attemptCount.Add(1)
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("error"))
+ }))
+ defer server.Close()
+
+ // Create a temporary file for download
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ // Download the file with retryDelay = 0 (should not retry)
+ if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil {
+ t.Fatal("expected error, got nil")
+ }
+
+ // Verify it only made 1 attempt (no retry)
+ if attemptCount.Load() != 1 {
+ t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
+ }
+}
diff --git a/client/internal/updatemanager/installer/binary_nowindows.go b/client/internal/updatemanager/installer/binary_nowindows.go
new file mode 100644
index 000000000..19f3bef83
--- /dev/null
+++ b/client/internal/updatemanager/installer/binary_nowindows.go
@@ -0,0 +1,7 @@
+//go:build !windows
+
+package installer
+
+func UpdaterBinaryNameWithoutExtension() string {
+ return updaterBinary
+}
diff --git a/client/internal/updatemanager/installer/binary_windows.go b/client/internal/updatemanager/installer/binary_windows.go
new file mode 100644
index 000000000..4c66391c2
--- /dev/null
+++ b/client/internal/updatemanager/installer/binary_windows.go
@@ -0,0 +1,11 @@
+package installer
+
+import (
+ "path/filepath"
+ "strings"
+)
+
+func UpdaterBinaryNameWithoutExtension() string {
+ ext := filepath.Ext(updaterBinary)
+ return strings.TrimSuffix(updaterBinary, ext)
+}
diff --git a/client/internal/updatemanager/installer/doc.go b/client/internal/updatemanager/installer/doc.go
new file mode 100644
index 000000000..0a60454bb
--- /dev/null
+++ b/client/internal/updatemanager/installer/doc.go
@@ -0,0 +1,111 @@
+// Package installer provides functionality for managing NetBird application
+// updates and installations across Windows, macOS. It handles
+// the complete update lifecycle including artifact download, cryptographic verification,
+// installation execution, process management, and result reporting.
+//
+// # Architecture
+//
+// The installer package uses a two-process architecture to enable self-updates:
+//
+// 1. Service Process: The main NetBird daemon process that initiates updates
+// 2. Updater Process: A detached child process that performs the actual installation
+//
+// This separation is critical because:
+// - The service binary cannot update itself while running
+// - The installer (EXE/MSI/PKG) will terminate the service during installation
+// - The updater process survives service termination and restarts it after installation
+// - Results can be communicated back to the service after it restarts
+//
+// # Update Flow
+//
+// Service Process (RunInstallation):
+//
+// 1. Validates target version format (semver)
+// 2. Determines installer type (EXE, MSI, PKG, or Homebrew)
+// 3. Downloads installer file from GitHub releases (if applicable)
+// 4. Verifies installer signature using reposign package (cryptographic verification in service process before
+// launching updater)
+// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows)
+// 6. Launches updater process with detached mode:
+// - --temp-dir: Temporary directory path
+// - --service-dir: Service installation directory
+// - --installer-file: Path to downloaded installer (if applicable)
+// - --dry-run: Optional flag to test without actually installing
+// 7. Service process continues running (will be terminated by installer later)
+// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion
+//
+// Updater Process (Setup):
+//
+// 1. Receives parameters from service via command-line arguments
+// 2. Runs installer with appropriate silent/quiet flags:
+// - Windows EXE: installer.exe /S
+// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log
+// - macOS PKG: installer -pkg installer.pkg -target /
+// - macOS Homebrew: brew upgrade netbirdio/tap/netbird
+// 3. Installer terminates daemon and UI processes
+// 4. Installer replaces binaries with new version
+// 5. Updater waits for installer to complete
+// 6. Updater restarts daemon:
+// - Windows: netbird.exe service start
+// - macOS/Linux: netbird service start
+// 7. Updater restarts UI:
+// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser
+// - macOS: Uses launchctl asuser to launch NetBird.app for console user
+// - Linux: Not implemented (UI typically auto-starts)
+// 8. Updater writes result.json with success/error status
+// 9. Updater process exits
+//
+// # Result Communication
+//
+// The ResultHandler (result.go) manages communication between updater and service:
+//
+// Result Structure:
+//
+// type Result struct {
+// Success bool // true if installation succeeded
+// Error string // error message if Success is false
+// ExecutedAt time.Time // when installation completed
+// }
+//
+// Result files are automatically cleaned up after being read.
+//
+// # File Locations
+//
+// Temporary Directory (platform-specific):
+//
+// Windows:
+// - Path: %ProgramData%\Netbird\tmp-install
+// - Example: C:\ProgramData\Netbird\tmp-install
+//
+// macOS:
+// - Path: /var/lib/netbird/tmp-install
+// - Requires root permissions
+//
+// Files created during installation:
+//
+// tmp-install/
+// installer.log
+// updater[.exe] # Copy of service binary
+// netbird_installer_*.[exe|msi|pkg] # Downloaded installer
+// result.json # Installation result
+// msi.log # MSI verbose log (Windows MSI only)
+//
+// # API Reference
+//
+// # Cleanup
+//
+// CleanUpInstallerFiles() removes temporary files after successful installation:
+// - Downloaded installer files (*.exe, *.msi, *.pkg)
+// - Updater binary copy
+// - Does NOT remove result.json (cleaned by ResultHandler after read)
+// - Does NOT remove msi.log (kept for debugging)
+//
+// # Dry-Run Mode
+//
+// Dry-run mode allows testing the update process without actually installing:
+//
+// Enable via environment variable:
+//
+// export NB_AUTO_UPDATE_DRY_RUN=true
+// netbird service install-update 0.29.0
+package installer
diff --git a/client/internal/updatemanager/installer/installer.go b/client/internal/updatemanager/installer/installer.go
new file mode 100644
index 000000000..caf5873f8
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer.go
@@ -0,0 +1,50 @@
+//go:build !windows && !darwin
+
+package installer
+
+import (
+ "context"
+ "fmt"
+)
+
+const (
+ updaterBinary = "updater"
+)
+
+type Installer struct {
+ tempDir string
+}
+
+// New used by the service
+func New() *Installer {
+ return &Installer{}
+}
+
+// NewWithDir used by the updater process, get the tempDir from the service via cmd line
+func NewWithDir(tempDir string) *Installer {
+ return &Installer{
+ tempDir: tempDir,
+ }
+}
+
+func (u *Installer) TempDir() string {
+ return ""
+}
+
+func (c *Installer) LogFiles() []string {
+ return []string{}
+}
+
+func (u *Installer) CleanUpInstallerFiles() error {
+ return nil
+}
+
+func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error {
+ return fmt.Errorf("unsupported platform")
+}
+
+// Setup runs the installer with appropriate arguments and manages the daemon/UI state
+// This will be run by the updater process
+func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) {
+ return fmt.Errorf("unsupported platform")
+}
diff --git a/client/internal/updatemanager/installer/installer_common.go b/client/internal/updatemanager/installer/installer_common.go
new file mode 100644
index 000000000..03378d55f
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer_common.go
@@ -0,0 +1,293 @@
+//go:build windows || darwin
+
+package installer
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "path"
+ "path/filepath"
+ "strings"
+
+ "github.com/hashicorp/go-multierror"
+ goversion "github.com/hashicorp/go-version"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
+ "github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
+)
+
+type Installer struct {
+ tempDir string
+}
+
+// New used by the service
+func New() *Installer {
+ return &Installer{
+ tempDir: defaultTempDir,
+ }
+}
+
+// NewWithDir used by the updater process, get the tempDir from the service via cmd line
+func NewWithDir(tempDir string) *Installer {
+ return &Installer{
+ tempDir: tempDir,
+ }
+}
+
+// RunInstallation starts the updater process to run the installation
+// This will run by the original service process
+func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) {
+ resultHandler := NewResultHandler(u.tempDir)
+
+ defer func() {
+ if err != nil {
+ if writeErr := resultHandler.WriteErr(err); writeErr != nil {
+ log.Errorf("failed to write error result: %v", writeErr)
+ }
+ }
+ }()
+
+ if err := validateTargetVersion(targetVersion); err != nil {
+ return err
+ }
+
+ if err := u.mkTempDir(); err != nil {
+ return err
+ }
+
+ var installerFile string
+ // Download files only when not using any third-party store
+ if installerType := TypeOfInstaller(ctx); installerType.Downloadable() {
+ log.Infof("download installer")
+ var err error
+ installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion)
+ if err != nil {
+ log.Errorf("failed to download installer: %v", err)
+ return err
+ }
+
+ artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL)
+ if err != nil {
+ log.Errorf("failed to create artifact verify: %v", err)
+ return err
+ }
+
+ if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil {
+ log.Errorf("artifact verification error: %v", err)
+ return err
+ }
+ }
+
+ log.Infof("running installer")
+ updaterPath, err := u.copyUpdater()
+ if err != nil {
+ return err
+ }
+
+ // the directory where the service has been installed
+ workspace, err := getServiceDir()
+ if err != nil {
+ return err
+ }
+
+ args := []string{
+ "--temp-dir", u.tempDir,
+ "--service-dir", workspace,
+ }
+
+ if isDryRunEnabled() {
+ args = append(args, "--dry-run=true")
+ }
+
+ if installerFile != "" {
+ args = append(args, "--installer-file", installerFile)
+ }
+
+ updateCmd := exec.Command(updaterPath, args...)
+ log.Infof("starting updater process: %s", updateCmd.String())
+
+ // Configure the updater to run in a separate session/process group
+ // so it survives the parent daemon being stopped
+ setUpdaterProcAttr(updateCmd)
+
+ // Start the updater process asynchronously
+ if err := updateCmd.Start(); err != nil {
+ return err
+ }
+
+ pid := updateCmd.Process.Pid
+ log.Infof("updater started with PID %d", pid)
+
+ // Release the process so the OS can fully detach it
+ if err := updateCmd.Process.Release(); err != nil {
+ log.Warnf("failed to release updater process: %v", err)
+ }
+
+ return nil
+}
+
+// CleanUpInstallerFiles
+// - the installer file (pkg, exe, msi)
+// - the selfcopy updater.exe
+func (u *Installer) CleanUpInstallerFiles() error {
+ // Check if tempDir exists
+ info, err := os.Stat(u.tempDir)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
+ return err
+ }
+
+ if !info.IsDir() {
+ return nil
+ }
+
+ var merr *multierror.Error
+
+ if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) {
+ merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err))
+ }
+
+ entries, err := os.ReadDir(u.tempDir)
+ if err != nil {
+ return err
+ }
+
+ for _, entry := range entries {
+ if entry.IsDir() {
+ continue
+ }
+
+ name := entry.Name()
+ for _, ext := range binaryExtensions {
+ if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
+ if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err))
+ }
+ break
+ }
+ }
+ }
+
+ return merr.ErrorOrNil()
+}
+
+func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) {
+ fileURL := urlWithVersionArch(installerType, targetVersion)
+
+ // Clean up temp directory on error
+ var success bool
+ defer func() {
+ if !success {
+ if err := os.RemoveAll(u.tempDir); err != nil {
+ log.Errorf("error cleaning up temporary directory: %v", err)
+ }
+ }
+ }()
+
+ fileName := path.Base(fileURL)
+ if fileName == "." || fileName == "/" || fileName == "" {
+ return "", fmt.Errorf("invalid file URL: %s", fileURL)
+ }
+
+ outputFilePath := filepath.Join(u.tempDir, fileName)
+ if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil {
+ return "", err
+ }
+
+ success = true
+ return outputFilePath, nil
+}
+
+func (u *Installer) TempDir() string {
+ return u.tempDir
+}
+
+func (u *Installer) mkTempDir() error {
+ if err := os.MkdirAll(u.tempDir, 0o755); err != nil {
+ log.Debugf("failed to create tempdir: %s", u.tempDir)
+ return err
+ }
+ return nil
+}
+
+func (u *Installer) copyUpdater() (string, error) {
+ src, err := getServiceBinary()
+ if err != nil {
+ return "", fmt.Errorf("failed to get updater binary: %w", err)
+ }
+
+ dst := filepath.Join(u.tempDir, updaterBinary)
+ if err := copyFile(src, dst); err != nil {
+ return "", fmt.Errorf("failed to copy updater binary: %w", err)
+ }
+
+ if err := os.Chmod(dst, 0o755); err != nil {
+ return "", fmt.Errorf("failed to set permissions: %w", err)
+ }
+
+ return dst, nil
+}
+
+func validateTargetVersion(targetVersion string) error {
+ if targetVersion == "" {
+ return fmt.Errorf("target version cannot be empty")
+ }
+
+ _, err := goversion.NewVersion(targetVersion)
+ if err != nil {
+ return fmt.Errorf("invalid target version %q: %w", targetVersion, err)
+ }
+
+ return nil
+}
+
+func copyFile(src, dst string) error {
+ log.Infof("copying %s to %s", src, dst)
+ in, err := os.Open(src)
+ if err != nil {
+ return fmt.Errorf("open source: %w", err)
+ }
+ defer func() {
+ if err := in.Close(); err != nil {
+ log.Warnf("failed to close source file: %v", err)
+ }
+ }()
+
+ out, err := os.Create(dst)
+ if err != nil {
+ return fmt.Errorf("create destination: %w", err)
+ }
+ defer func() {
+ if err := out.Close(); err != nil {
+ log.Warnf("failed to close destination file: %v", err)
+ }
+ }()
+
+ if _, err := io.Copy(out, in); err != nil {
+ return fmt.Errorf("copy: %w", err)
+ }
+
+ return nil
+}
+
+func getServiceDir() (string, error) {
+ exePath, err := os.Executable()
+ if err != nil {
+ return "", err
+ }
+ return filepath.Dir(exePath), nil
+}
+
+func getServiceBinary() (string, error) {
+ return os.Executable()
+}
+
+func isDryRunEnabled() bool {
+ return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true")
+}
diff --git a/client/internal/updatemanager/installer/installer_log_darwin.go b/client/internal/updatemanager/installer/installer_log_darwin.go
new file mode 100644
index 000000000..50dd5d197
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer_log_darwin.go
@@ -0,0 +1,11 @@
+package installer
+
+import (
+ "path/filepath"
+)
+
+func (u *Installer) LogFiles() []string {
+ return []string{
+ filepath.Join(u.tempDir, LogFile),
+ }
+}
diff --git a/client/internal/updatemanager/installer/installer_log_windows.go b/client/internal/updatemanager/installer/installer_log_windows.go
new file mode 100644
index 000000000..96e4cfd1f
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer_log_windows.go
@@ -0,0 +1,12 @@
+package installer
+
+import (
+ "path/filepath"
+)
+
+func (u *Installer) LogFiles() []string {
+ return []string{
+ filepath.Join(u.tempDir, msiLogFile),
+ filepath.Join(u.tempDir, LogFile),
+ }
+}
diff --git a/client/internal/updatemanager/installer/installer_run_darwin.go b/client/internal/updatemanager/installer/installer_run_darwin.go
new file mode 100644
index 000000000..462e2c227
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer_run_darwin.go
@@ -0,0 +1,238 @@
+package installer
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "os/user"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "syscall"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ daemonName = "netbird"
+ updaterBinary = "updater"
+ uiBinary = "/Applications/NetBird.app"
+
+ defaultTempDir = "/var/lib/netbird/tmp-install"
+
+ pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
+)
+
+var (
+ binaryExtensions = []string{"pkg"}
+)
+
+// Setup runs the installer with appropriate arguments and manages the daemon/UI state
+// This will be run by the updater process
+func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
+ resultHandler := NewResultHandler(u.tempDir)
+
+ // Always ensure daemon and UI are restarted after setup
+ defer func() {
+ log.Infof("write out result")
+ var err error
+ if resultErr == nil {
+ err = resultHandler.WriteSuccess()
+ } else {
+ err = resultHandler.WriteErr(resultErr)
+ }
+ if err != nil {
+ log.Errorf("failed to write update result: %v", err)
+ }
+
+ // skip service restart if dry-run mode is enabled
+ if dryRun {
+ return
+ }
+
+ log.Infof("starting daemon back")
+ if err := u.startDaemon(daemonFolder); err != nil {
+ log.Errorf("failed to start daemon: %v", err)
+ }
+
+ log.Infof("starting UI back")
+ if err := u.startUIAsUser(); err != nil {
+ log.Errorf("failed to start UI: %v", err)
+ }
+
+ }()
+
+ if dryRun {
+ time.Sleep(7 * time.Second)
+ log.Infof("dry-run mode enabled, skipping actual installation")
+ resultErr = fmt.Errorf("dry-run mode enabled")
+ return
+ }
+
+ switch TypeOfInstaller(ctx) {
+ case TypePKG:
+ resultErr = u.installPkgFile(ctx, installerFile)
+ case TypeHomebrew:
+ resultErr = u.updateHomeBrew(ctx)
+ }
+
+ return resultErr
+}
+
+func (u *Installer) startDaemon(daemonFolder string) error {
+ log.Infof("starting netbird service")
+ ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancel()
+
+ cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
+ if output, err := cmd.CombinedOutput(); err != nil {
+ log.Warnf("failed to start netbird service: %v, output: %s", err, string(output))
+ return err
+ }
+ log.Infof("netbird service started successfully")
+ return nil
+}
+
+func (u *Installer) startUIAsUser() error {
+ log.Infof("starting netbird-ui: %s", uiBinary)
+
+ // Get the current console user
+ cmd := exec.Command("stat", "-f", "%Su", "/dev/console")
+ output, err := cmd.Output()
+ if err != nil {
+ return fmt.Errorf("failed to get console user: %w", err)
+ }
+
+ username := strings.TrimSpace(string(output))
+ if username == "" || username == "root" {
+ return fmt.Errorf("no active user session found")
+ }
+
+ log.Infof("starting UI for user: %s", username)
+
+ // Get user's UID
+ userInfo, err := user.Lookup(username)
+ if err != nil {
+ return fmt.Errorf("failed to lookup user %s: %w", username, err)
+ }
+
+ // Start the UI process as the console user using launchctl
+ // This ensures the app runs in the user's context with proper GUI access
+ launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary)
+ log.Infof("launchCmd: %s", launchCmd.String())
+ // Set the user's home directory for proper macOS app behavior
+ launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir)
+ log.Infof("set HOME environment variable: %s", userInfo.HomeDir)
+
+ if err := launchCmd.Start(); err != nil {
+ return fmt.Errorf("failed to start UI process: %w", err)
+ }
+
+ // Release the process so it can run independently
+ if err := launchCmd.Process.Release(); err != nil {
+ log.Warnf("failed to release UI process: %v", err)
+ }
+
+ log.Infof("netbird-ui started successfully for user %s", username)
+ return nil
+}
+
+func (u *Installer) installPkgFile(ctx context.Context, path string) error {
+ log.Infof("installing pkg file: %s", path)
+
+ // Kill any existing UI processes before installation
+ // This ensures the postinstall script's "open $APP" will start the new version
+ u.killUI()
+
+ volume := "/"
+
+ cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
+ if err := cmd.Start(); err != nil {
+ return fmt.Errorf("error running pkg file: %w", err)
+ }
+ log.Infof("installer started with PID %d", cmd.Process.Pid)
+ if err := cmd.Wait(); err != nil {
+ return fmt.Errorf("error running pkg file: %w", err)
+ }
+ log.Infof("pkg file installed successfully")
+ return nil
+}
+
+func (u *Installer) updateHomeBrew(ctx context.Context) error {
+ log.Infof("updating homebrew")
+
+ // Kill any existing UI processes before upgrade
+ // This ensures the new version will be started after upgrade
+ u.killUI()
+
+ // Homebrew must be run as a non-root user
+ // To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
+ // Check both Apple Silicon and Intel Mac paths
+ brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/"
+ brewBinPath := "/opt/homebrew/bin/brew"
+ if _, err := os.Stat(brewTapPath); os.IsNotExist(err) {
+ // Try Intel Mac path
+ brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/"
+ brewBinPath = "/usr/local/bin/brew"
+ }
+
+ fileInfo, err := os.Stat(brewTapPath)
+ if err != nil {
+ return fmt.Errorf("error getting homebrew installation path info: %w", err)
+ }
+
+ fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
+ if !ok {
+ return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
+ }
+
+ // Get username from UID
+ brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
+ if err != nil {
+ return fmt.Errorf("error looking up brew installer user: %w", err)
+ }
+ userName := brewUser.Username
+ // Get user HOME, required for brew to run correctly
+ // https://github.com/Homebrew/brew/issues/15833
+ homeDir := brewUser.HomeDir
+
+ // Check if netbird-ui is installed (must run as the brew user, not root)
+ checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui")
+ checkUICmd.Env = append(os.Environ(), "HOME="+homeDir)
+ uiInstalled := checkUICmd.Run() == nil
+
+ // Homebrew does not support installing specific versions
+ // Thus it will always update to latest and ignore targetVersion
+ upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"}
+ if uiInstalled {
+ upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
+ }
+
+ cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...)
+ cmd.Env = append(os.Environ(), "HOME="+homeDir)
+
+ if output, err := cmd.CombinedOutput(); err != nil {
+ return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output))
+ }
+
+ log.Infof("homebrew updated successfully")
+ return nil
+}
+
+func (u *Installer) killUI() {
+ log.Infof("killing existing netbird-ui processes")
+ cmd := exec.Command("pkill", "-x", "netbird-ui")
+ if output, err := cmd.CombinedOutput(); err != nil {
+ // pkill returns exit code 1 if no processes matched, which is fine
+ log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output))
+ } else {
+ log.Infof("netbird-ui processes killed")
+ }
+}
+
+func urlWithVersionArch(_ Type, version string) string {
+ url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
+ return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
+}
diff --git a/client/internal/updatemanager/installer/installer_run_windows.go b/client/internal/updatemanager/installer/installer_run_windows.go
new file mode 100644
index 000000000..353cd885d
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer_run_windows.go
@@ -0,0 +1,213 @@
+package installer
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "time"
+ "unsafe"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+)
+
+const (
+ daemonName = "netbird.exe"
+ uiName = "netbird-ui.exe"
+ updaterBinary = "updater.exe"
+
+ msiLogFile = "msi.log"
+
+ msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
+ exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
+)
+
+var (
+ defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install")
+
+ // for the cleanup
+ binaryExtensions = []string{"msi", "exe"}
+)
+
+// Setup runs the installer with appropriate arguments and manages the daemon/UI state
+// This will be run by the updater process
+func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
+ resultHandler := NewResultHandler(u.tempDir)
+
+ // Always ensure daemon and UI are restarted after setup
+ defer func() {
+ log.Infof("starting daemon back")
+ if err := u.startDaemon(daemonFolder); err != nil {
+ log.Errorf("failed to start daemon: %v", err)
+ }
+
+ log.Infof("starting UI back")
+ if err := u.startUIAsUser(daemonFolder); err != nil {
+ log.Errorf("failed to start UI: %v", err)
+ }
+
+ log.Infof("write out result")
+ var err error
+ if resultErr == nil {
+ err = resultHandler.WriteSuccess()
+ } else {
+ err = resultHandler.WriteErr(resultErr)
+ }
+ if err != nil {
+ log.Errorf("failed to write update result: %v", err)
+ }
+ }()
+
+ if dryRun {
+ log.Infof("dry-run mode enabled, skipping actual installation")
+ resultErr = fmt.Errorf("dry-run mode enabled")
+ return
+ }
+
+ installerType, err := typeByFileExtension(installerFile)
+ if err != nil {
+ log.Debugf("%v", err)
+ resultErr = err
+ return
+ }
+
+ var cmd *exec.Cmd
+ switch installerType {
+ case TypeExe:
+ log.Infof("run exe installer: %s", installerFile)
+ cmd = exec.CommandContext(ctx, installerFile, "/S")
+ default:
+ installerDir := filepath.Dir(installerFile)
+ logPath := filepath.Join(installerDir, msiLogFile)
+ log.Infof("run msi installer: %s", installerFile)
+ cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath)
+ }
+
+ cmd.Dir = filepath.Dir(installerFile)
+
+ if resultErr = cmd.Start(); resultErr != nil {
+ log.Errorf("error starting installer: %v", resultErr)
+ return
+ }
+
+ log.Infof("installer started with PID %d", cmd.Process.Pid)
+ if resultErr = cmd.Wait(); resultErr != nil {
+ log.Errorf("installer process finished with error: %v", resultErr)
+ return
+ }
+
+ return nil
+}
+
+func (u *Installer) startDaemon(daemonFolder string) error {
+ log.Infof("starting netbird service")
+ ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancel()
+
+ cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
+ if output, err := cmd.CombinedOutput(); err != nil {
+ log.Debugf("failed to start netbird service: %v, output: %s", err, string(output))
+ return err
+ }
+ log.Infof("netbird service started successfully")
+ return nil
+}
+
+func (u *Installer) startUIAsUser(daemonFolder string) error {
+ uiPath := filepath.Join(daemonFolder, uiName)
+ log.Infof("starting netbird-ui: %s", uiPath)
+
+ // Get the active console session ID
+ sessionID := windows.WTSGetActiveConsoleSessionId()
+ if sessionID == 0xFFFFFFFF {
+ return fmt.Errorf("no active user session found")
+ }
+
+ // Get the user token for that session
+ var userToken windows.Token
+ err := windows.WTSQueryUserToken(sessionID, &userToken)
+ if err != nil {
+ return fmt.Errorf("failed to query user token: %w", err)
+ }
+ defer func() {
+ if err := userToken.Close(); err != nil {
+ log.Warnf("failed to close user token: %v", err)
+ }
+ }()
+
+ // Duplicate the token to a primary token
+ var primaryToken windows.Token
+ err = windows.DuplicateTokenEx(
+ userToken,
+ windows.MAXIMUM_ALLOWED,
+ nil,
+ windows.SecurityImpersonation,
+ windows.TokenPrimary,
+ &primaryToken,
+ )
+ if err != nil {
+ return fmt.Errorf("failed to duplicate token: %w", err)
+ }
+ defer func() {
+ if err := primaryToken.Close(); err != nil {
+ log.Warnf("failed to close token: %v", err)
+ }
+ }()
+
+ // Prepare startup info
+ var si windows.StartupInfo
+ si.Cb = uint32(unsafe.Sizeof(si))
+ si.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
+
+ var pi windows.ProcessInformation
+
+ cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath))
+ if err != nil {
+ return fmt.Errorf("failed to convert path to UTF16: %w", err)
+ }
+
+ creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT
+
+ err = windows.CreateProcessAsUser(
+ primaryToken,
+ nil,
+ cmdLine,
+ nil,
+ nil,
+ false,
+ creationFlags,
+ nil,
+ nil,
+ &si,
+ &pi,
+ )
+ if err != nil {
+ return fmt.Errorf("CreateProcessAsUser failed: %w", err)
+ }
+
+ // Close handles
+ if err := windows.CloseHandle(pi.Process); err != nil {
+ log.Warnf("failed to close process handle: %v", err)
+ }
+ if err := windows.CloseHandle(pi.Thread); err != nil {
+ log.Warnf("failed to close thread handle: %v", err)
+ }
+
+ log.Infof("netbird-ui started successfully in session %d", sessionID)
+ return nil
+}
+
+func urlWithVersionArch(it Type, version string) string {
+ var url string
+ if it == TypeExe {
+ url = exeDownloadURL
+ } else {
+ url = msiDownloadURL
+ }
+ url = strings.ReplaceAll(url, "%version", version)
+ return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
+}
diff --git a/client/internal/updatemanager/installer/log.go b/client/internal/updatemanager/installer/log.go
new file mode 100644
index 000000000..8b60dba28
--- /dev/null
+++ b/client/internal/updatemanager/installer/log.go
@@ -0,0 +1,5 @@
+package installer
+
+const (
+ LogFile = "installer.log"
+)
diff --git a/client/internal/updatemanager/installer/procattr_darwin.go b/client/internal/updatemanager/installer/procattr_darwin.go
new file mode 100644
index 000000000..56f2018bb
--- /dev/null
+++ b/client/internal/updatemanager/installer/procattr_darwin.go
@@ -0,0 +1,15 @@
+package installer
+
+import (
+ "os/exec"
+ "syscall"
+)
+
+// setUpdaterProcAttr configures the updater process to run in a new session,
+// making it independent of the parent daemon process. This ensures the updater
+// survives when the daemon is stopped during the pkg installation.
+func setUpdaterProcAttr(cmd *exec.Cmd) {
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setsid: true,
+ }
+}
diff --git a/client/internal/updatemanager/installer/procattr_windows.go b/client/internal/updatemanager/installer/procattr_windows.go
new file mode 100644
index 000000000..29a8a2de0
--- /dev/null
+++ b/client/internal/updatemanager/installer/procattr_windows.go
@@ -0,0 +1,14 @@
+package installer
+
+import (
+ "os/exec"
+ "syscall"
+)
+
+// setUpdaterProcAttr configures the updater process to run detached from the parent,
+// making it independent of the parent daemon process.
+func setUpdaterProcAttr(cmd *exec.Cmd) {
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
+ }
+}
diff --git a/client/internal/updatemanager/installer/repourl_dev.go b/client/internal/updatemanager/installer/repourl_dev.go
new file mode 100644
index 000000000..088821ad3
--- /dev/null
+++ b/client/internal/updatemanager/installer/repourl_dev.go
@@ -0,0 +1,7 @@
+//go:build devartifactsign
+
+package installer
+
+const (
+ DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo"
+)
diff --git a/client/internal/updatemanager/installer/repourl_prod.go b/client/internal/updatemanager/installer/repourl_prod.go
new file mode 100644
index 000000000..abddc62c1
--- /dev/null
+++ b/client/internal/updatemanager/installer/repourl_prod.go
@@ -0,0 +1,7 @@
+//go:build !devartifactsign
+
+package installer
+
+const (
+ DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures"
+)
diff --git a/client/internal/updatemanager/installer/result.go b/client/internal/updatemanager/installer/result.go
new file mode 100644
index 000000000..03d08d527
--- /dev/null
+++ b/client/internal/updatemanager/installer/result.go
@@ -0,0 +1,230 @@
+package installer
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/fsnotify/fsnotify"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ resultFile = "result.json"
+)
+
+type Result struct {
+ Success bool
+ Error string
+ ExecutedAt time.Time
+}
+
+// ResultHandler handles reading and writing update results
+type ResultHandler struct {
+ resultFile string
+}
+
+// NewResultHandler creates a new communicator with the given directory path
+// The result file will be created as "result.json" in the specified directory
+func NewResultHandler(installerDir string) *ResultHandler {
+ // Create it if it doesn't exist
+ // do not care if already exists
+ _ = os.MkdirAll(installerDir, 0o700)
+
+ rh := &ResultHandler{
+ resultFile: filepath.Join(installerDir, resultFile),
+ }
+ return rh
+}
+
+func (rh *ResultHandler) GetErrorResultReason() string {
+ result, err := rh.tryReadResult()
+ if err == nil && !result.Success {
+ return result.Error
+ }
+
+ if err := rh.cleanup(); err != nil {
+ log.Warnf("failed to cleanup result file: %v", err)
+ }
+
+ return ""
+}
+
+func (rh *ResultHandler) WriteSuccess() error {
+ result := Result{
+ Success: true,
+ ExecutedAt: time.Now(),
+ }
+ return rh.write(result)
+}
+
+func (rh *ResultHandler) WriteErr(errReason error) error {
+ result := Result{
+ Success: false,
+ Error: errReason.Error(),
+ ExecutedAt: time.Now(),
+ }
+ return rh.write(result)
+}
+
+func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) {
+ log.Infof("start watching result: %s", rh.resultFile)
+
+ // Check if file already exists (updater finished before we started watching)
+ if result, err := rh.tryReadResult(); err == nil {
+ log.Infof("installer result: %v", result)
+ return result, nil
+ }
+
+ dir := filepath.Dir(rh.resultFile)
+
+ if err := rh.waitForDirectory(ctx, dir); err != nil {
+ return Result{}, err
+ }
+
+ return rh.watchForResultFile(ctx, dir)
+}
+
+func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error {
+ ticker := time.NewTicker(300 * time.Millisecond)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-ticker.C:
+ if info, err := os.Stat(dir); err == nil && info.IsDir() {
+ return nil
+ }
+ }
+ }
+}
+
+func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) {
+ watcher, err := fsnotify.NewWatcher()
+ if err != nil {
+ log.Error(err)
+ return Result{}, err
+ }
+
+ defer func() {
+ if err := watcher.Close(); err != nil {
+ log.Warnf("failed to close watcher: %v", err)
+ }
+ }()
+
+ if err := watcher.Add(dir); err != nil {
+ return Result{}, fmt.Errorf("failed to watch directory: %v", err)
+ }
+
+ // Check again after setting up watcher to avoid race condition
+ // (file could have been created between initial check and watcher setup)
+ if result, err := rh.tryReadResult(); err == nil {
+ log.Infof("installer result: %v", result)
+ return result, nil
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ return Result{}, ctx.Err()
+ case event, ok := <-watcher.Events:
+ if !ok {
+ return Result{}, errors.New("watcher closed unexpectedly")
+ }
+
+ if result, done := rh.handleWatchEvent(event); done {
+ return result, nil
+ }
+ case err, ok := <-watcher.Errors:
+ if !ok {
+ return Result{}, errors.New("watcher closed unexpectedly")
+ }
+ return Result{}, fmt.Errorf("watcher error: %w", err)
+ }
+ }
+}
+
+func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) {
+ if event.Name != rh.resultFile {
+ return Result{}, false
+ }
+
+ if event.Has(fsnotify.Create) {
+ result, err := rh.tryReadResult()
+ if err != nil {
+ log.Debugf("error while reading result: %v", err)
+ return result, true
+ }
+ log.Infof("installer result: %v", result)
+ return result, true
+ }
+
+ return Result{}, false
+}
+
+// Write writes the update result to a file for the UI to read
+func (rh *ResultHandler) write(result Result) error {
+ log.Infof("write out installer result to: %s", rh.resultFile)
+ // Ensure directory exists
+ dir := filepath.Dir(rh.resultFile)
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ log.Errorf("failed to create directory %s: %v", dir, err)
+ return err
+ }
+
+ data, err := json.Marshal(result)
+ if err != nil {
+ return err
+ }
+
+ // Write to a temporary file first, then rename for atomic operation
+ tmpPath := rh.resultFile + ".tmp"
+ if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
+ log.Errorf("failed to create temp file: %s", err)
+ return err
+ }
+
+ // Atomic rename
+ if err := os.Rename(tmpPath, rh.resultFile); err != nil {
+ if cleanupErr := os.Remove(tmpPath); cleanupErr != nil {
+ log.Warnf("Failed to remove temp result file: %v", err)
+ }
+ return err
+ }
+
+ return nil
+}
+
+func (rh *ResultHandler) cleanup() error {
+ err := os.Remove(rh.resultFile)
+ if err != nil && !os.IsNotExist(err) {
+ return err
+ }
+ log.Debugf("delete installer result file: %s", rh.resultFile)
+ return nil
+}
+
+// tryReadResult attempts to read and validate the result file
+func (rh *ResultHandler) tryReadResult() (Result, error) {
+ data, err := os.ReadFile(rh.resultFile)
+ if err != nil {
+ return Result{}, err
+ }
+
+ var result Result
+ if err := json.Unmarshal(data, &result); err != nil {
+ return Result{}, fmt.Errorf("invalid result format: %w", err)
+ }
+
+ if err := rh.cleanup(); err != nil {
+ log.Warnf("failed to cleanup result file: %v", err)
+ }
+
+ return result, nil
+}
diff --git a/client/internal/updatemanager/installer/types.go b/client/internal/updatemanager/installer/types.go
new file mode 100644
index 000000000..656d84f88
--- /dev/null
+++ b/client/internal/updatemanager/installer/types.go
@@ -0,0 +1,14 @@
+package installer
+
+type Type struct {
+ name string
+ downloadable bool
+}
+
+func (t Type) String() string {
+ return t.name
+}
+
+func (t Type) Downloadable() bool {
+ return t.downloadable
+}
diff --git a/client/internal/updatemanager/installer/types_darwin.go b/client/internal/updatemanager/installer/types_darwin.go
new file mode 100644
index 000000000..95a0cb737
--- /dev/null
+++ b/client/internal/updatemanager/installer/types_darwin.go
@@ -0,0 +1,22 @@
+package installer
+
+import (
+ "context"
+ "os/exec"
+)
+
+var (
+ TypeHomebrew = Type{name: "Homebrew", downloadable: false}
+ TypePKG = Type{name: "pkg", downloadable: true}
+)
+
+func TypeOfInstaller(ctx context.Context) Type {
+ cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
+ _, err := cmd.Output()
+ if err != nil && cmd.ProcessState.ExitCode() == 1 {
+ // Not installed using pkg file, thus installed using Homebrew
+
+ return TypeHomebrew
+ }
+ return TypePKG
+}
diff --git a/client/internal/updatemanager/installer/types_windows.go b/client/internal/updatemanager/installer/types_windows.go
new file mode 100644
index 000000000..d4e5d83bd
--- /dev/null
+++ b/client/internal/updatemanager/installer/types_windows.go
@@ -0,0 +1,51 @@
+package installer
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows/registry"
+)
+
+const (
+ uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
+ uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
+)
+
+var (
+ TypeExe = Type{name: "EXE", downloadable: true}
+ TypeMSI = Type{name: "MSI", downloadable: true}
+)
+
+func TypeOfInstaller(_ context.Context) Type {
+ paths := []string{uninstallKeyPath64, uninstallKeyPath32}
+
+ for _, path := range paths {
+ k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
+ if err != nil {
+ continue
+ }
+
+ if err := k.Close(); err != nil {
+ log.Warnf("Error closing registry key: %v", err)
+ }
+ return TypeExe
+
+ }
+
+ log.Debug("No registry entry found for Netbird, assuming MSI installation")
+ return TypeMSI
+}
+
+func typeByFileExtension(filePath string) (Type, error) {
+ switch {
+ case strings.HasSuffix(strings.ToLower(filePath), ".exe"):
+ return TypeExe, nil
+ case strings.HasSuffix(strings.ToLower(filePath), ".msi"):
+ return TypeMSI, nil
+ default:
+ return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath)
+ }
+}
diff --git a/client/internal/updatemanager/manager.go b/client/internal/updatemanager/manager.go
new file mode 100644
index 000000000..eae11de56
--- /dev/null
+++ b/client/internal/updatemanager/manager.go
@@ -0,0 +1,374 @@
+//go:build windows || darwin
+
+package updatemanager
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "runtime"
+ "sync"
+ "time"
+
+ v "github.com/hashicorp/go-version"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/internal/statemanager"
+ "github.com/netbirdio/netbird/client/internal/updatemanager/installer"
+ cProto "github.com/netbirdio/netbird/client/proto"
+ "github.com/netbirdio/netbird/version"
+)
+
+const (
+ latestVersion = "latest"
+ // this version will be ignored
+ developmentVersion = "development"
+)
+
+var errNoUpdateState = errors.New("no update state found")
+
+type UpdateState struct {
+ PreUpdateVersion string
+ TargetVersion string
+}
+
+func (u UpdateState) Name() string {
+ return "autoUpdate"
+}
+
+type Manager struct {
+ statusRecorder *peer.Status
+ stateManager *statemanager.Manager
+
+ lastTrigger time.Time
+ mgmUpdateChan chan struct{}
+ updateChannel chan struct{}
+ currentVersion string
+ update UpdateInterface
+ wg sync.WaitGroup
+
+ cancel context.CancelFunc
+
+ expectedVersion *v.Version
+ updateToLatestVersion bool
+
+ // updateMutex protect update and expectedVersion fields
+ updateMutex sync.Mutex
+
+ triggerUpdateFn func(context.Context, string) error
+}
+
+func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
+ if runtime.GOOS == "darwin" {
+ isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
+ if isBrew {
+ log.Warnf("auto-update disabled on Home Brew installation")
+ return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
+ }
+ }
+ return newManager(statusRecorder, stateManager)
+}
+
+func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
+ manager := &Manager{
+ statusRecorder: statusRecorder,
+ stateManager: stateManager,
+ mgmUpdateChan: make(chan struct{}, 1),
+ updateChannel: make(chan struct{}, 1),
+ currentVersion: version.NetbirdVersion(),
+ update: version.NewUpdate("nb/client"),
+ }
+ manager.triggerUpdateFn = manager.triggerUpdate
+
+ stateManager.RegisterState(&UpdateState{})
+
+ return manager, nil
+}
+
+// CheckUpdateSuccess checks if the update was successful and send a notification.
+// It works without to start the update manager.
+func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
+ reason := m.lastResultErrReason()
+ if reason != "" {
+ m.statusRecorder.PublishEvent(
+ cProto.SystemEvent_ERROR,
+ cProto.SystemEvent_SYSTEM,
+ "Auto-update failed",
+ fmt.Sprintf("Auto-update failed: %s", reason),
+ nil,
+ )
+ }
+
+ updateState, err := m.loadAndDeleteUpdateState(ctx)
+ if err != nil {
+ if errors.Is(err, errNoUpdateState) {
+ return
+ }
+ log.Errorf("failed to load update state: %v", err)
+ return
+ }
+
+ log.Debugf("auto-update state loaded, %v", *updateState)
+
+ if updateState.TargetVersion == m.currentVersion {
+ m.statusRecorder.PublishEvent(
+ cProto.SystemEvent_INFO,
+ cProto.SystemEvent_SYSTEM,
+ "Auto-update completed",
+ fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion),
+ nil,
+ )
+ return
+ }
+}
+
+func (m *Manager) Start(ctx context.Context) {
+ if m.cancel != nil {
+ log.Errorf("Manager already started")
+ return
+ }
+
+ m.update.SetDaemonVersion(m.currentVersion)
+ m.update.SetOnUpdateListener(func() {
+ select {
+ case m.updateChannel <- struct{}{}:
+ default:
+ }
+ })
+ go m.update.StartFetcher()
+
+ ctx, cancel := context.WithCancel(ctx)
+ m.cancel = cancel
+
+ m.wg.Add(1)
+ go m.updateLoop(ctx)
+}
+
+func (m *Manager) SetVersion(expectedVersion string) {
+ log.Infof("set expected agent version for upgrade: %s", expectedVersion)
+ if m.cancel == nil {
+ log.Errorf("manager not started")
+ return
+ }
+
+ m.updateMutex.Lock()
+ defer m.updateMutex.Unlock()
+
+ if expectedVersion == "" {
+ log.Errorf("empty expected version provided")
+ m.expectedVersion = nil
+ m.updateToLatestVersion = false
+ return
+ }
+
+ if expectedVersion == latestVersion {
+ m.updateToLatestVersion = true
+ m.expectedVersion = nil
+ } else {
+ expectedSemVer, err := v.NewVersion(expectedVersion)
+ if err != nil {
+ log.Errorf("error parsing version: %v", err)
+ return
+ }
+ if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) {
+ return
+ }
+ m.expectedVersion = expectedSemVer
+ m.updateToLatestVersion = false
+ }
+
+ select {
+ case m.mgmUpdateChan <- struct{}{}:
+ default:
+ }
+}
+
+func (m *Manager) Stop() {
+ if m.cancel == nil {
+ return
+ }
+
+ m.cancel()
+ m.updateMutex.Lock()
+ if m.update != nil {
+ m.update.StopWatch()
+ m.update = nil
+ }
+ m.updateMutex.Unlock()
+
+ m.wg.Wait()
+}
+
+func (m *Manager) onContextCancel() {
+ if m.cancel == nil {
+ return
+ }
+
+ m.updateMutex.Lock()
+ defer m.updateMutex.Unlock()
+ if m.update != nil {
+ m.update.StopWatch()
+ m.update = nil
+ }
+}
+
+func (m *Manager) updateLoop(ctx context.Context) {
+ defer m.wg.Done()
+
+ for {
+ select {
+ case <-ctx.Done():
+ m.onContextCancel()
+ return
+ case <-m.mgmUpdateChan:
+ case <-m.updateChannel:
+ log.Infof("fetched new version info")
+ }
+
+ m.handleUpdate(ctx)
+ }
+}
+
+func (m *Manager) handleUpdate(ctx context.Context) {
+ var updateVersion *v.Version
+
+ m.updateMutex.Lock()
+ if m.update == nil {
+ m.updateMutex.Unlock()
+ return
+ }
+
+ expectedVersion := m.expectedVersion
+ useLatest := m.updateToLatestVersion
+ curLatestVersion := m.update.LatestVersion()
+ m.updateMutex.Unlock()
+
+ switch {
+ // Resolve "latest" to actual version
+ case useLatest:
+ if curLatestVersion == nil {
+ log.Tracef("latest version not fetched yet")
+ return
+ }
+ updateVersion = curLatestVersion
+ // Update to specific version
+ case expectedVersion != nil:
+ updateVersion = expectedVersion
+ default:
+ log.Debugf("no expected version information set")
+ return
+ }
+
+ log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
+ if !m.shouldUpdate(updateVersion) {
+ return
+ }
+
+ m.lastTrigger = time.Now()
+ log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
+ m.statusRecorder.PublishEvent(
+ cProto.SystemEvent_CRITICAL,
+ cProto.SystemEvent_SYSTEM,
+ "Automatically updating client",
+ "Your client version is older than auto-update version set in Management, updating client now.",
+ nil,
+ )
+
+ m.statusRecorder.PublishEvent(
+ cProto.SystemEvent_CRITICAL,
+ cProto.SystemEvent_SYSTEM,
+ "",
+ "",
+ map[string]string{"progress_window": "show", "version": updateVersion.String()},
+ )
+
+ updateState := UpdateState{
+ PreUpdateVersion: m.currentVersion,
+ TargetVersion: updateVersion.String(),
+ }
+
+ if err := m.stateManager.UpdateState(updateState); err != nil {
+ log.Warnf("failed to update state: %v", err)
+ } else {
+ if err = m.stateManager.PersistState(ctx); err != nil {
+ log.Warnf("failed to persist state: %v", err)
+ }
+ }
+
+ if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
+ log.Errorf("Error triggering auto-update: %v", err)
+ m.statusRecorder.PublishEvent(
+ cProto.SystemEvent_ERROR,
+ cProto.SystemEvent_SYSTEM,
+ "Auto-update failed",
+ fmt.Sprintf("Auto-update failed: %v", err),
+ nil,
+ )
+ }
+}
+
+// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
+// Returns nil if no state exists.
+func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) {
+ stateType := &UpdateState{}
+
+ m.stateManager.RegisterState(stateType)
+ if err := m.stateManager.LoadState(stateType); err != nil {
+ return nil, fmt.Errorf("load state: %w", err)
+ }
+
+ state := m.stateManager.GetState(stateType)
+ if state == nil {
+ return nil, errNoUpdateState
+ }
+
+ updateState, ok := state.(*UpdateState)
+ if !ok {
+ return nil, fmt.Errorf("failed to cast state to UpdateState")
+ }
+
+ if err := m.stateManager.DeleteState(updateState); err != nil {
+ return nil, fmt.Errorf("delete state: %w", err)
+ }
+
+ if err := m.stateManager.PersistState(ctx); err != nil {
+ return nil, fmt.Errorf("persist state: %w", err)
+ }
+
+ return updateState, nil
+}
+
+func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
+ if m.currentVersion == developmentVersion {
+ log.Debugf("skipping auto-update, running development version")
+ return false
+ }
+ currentVersion, err := v.NewVersion(m.currentVersion)
+ if err != nil {
+ log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err)
+ return false
+ }
+ if currentVersion.GreaterThanOrEqual(updateVersion) {
+ log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion)
+ return false
+ }
+
+ if time.Since(m.lastTrigger) < 5*time.Minute {
+ log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
+ return false
+ }
+
+ return true
+}
+
+func (m *Manager) lastResultErrReason() string {
+ inst := installer.New()
+ result := installer.NewResultHandler(inst.TempDir())
+ return result.GetErrorResultReason()
+}
+
+func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
+ inst := installer.New()
+ return inst.RunInstallation(ctx, targetVersion)
+}
diff --git a/client/internal/updatemanager/manager_test.go b/client/internal/updatemanager/manager_test.go
new file mode 100644
index 000000000..20ddec10d
--- /dev/null
+++ b/client/internal/updatemanager/manager_test.go
@@ -0,0 +1,214 @@
+//go:build windows || darwin
+
+package updatemanager
+
+import (
+ "context"
+ "fmt"
+ "path"
+ "testing"
+ "time"
+
+ v "github.com/hashicorp/go-version"
+
+ "github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/internal/statemanager"
+)
+
+type versionUpdateMock struct {
+ latestVersion *v.Version
+ onUpdate func()
+}
+
+func (v versionUpdateMock) StopWatch() {}
+
+func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
+ return false
+}
+
+func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
+ v.onUpdate = updateFn
+}
+
+func (v versionUpdateMock) LatestVersion() *v.Version {
+ return v.latestVersion
+}
+
+func (v versionUpdateMock) StartFetcher() {}
+
+func Test_LatestVersion(t *testing.T) {
+ testMatrix := []struct {
+ name string
+ daemonVersion string
+ initialLatestVersion *v.Version
+ latestVersion *v.Version
+ shouldUpdateInit bool
+ shouldUpdateLater bool
+ }{
+ {
+ name: "Should only trigger update once due to time between triggers being < 5 Minutes",
+ daemonVersion: "1.0.0",
+ initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
+ latestVersion: v.Must(v.NewSemver("1.0.2")),
+ shouldUpdateInit: true,
+ shouldUpdateLater: false,
+ },
+ {
+ name: "Shouldn't update initially, but should update as soon as latest version is fetched",
+ daemonVersion: "1.0.0",
+ initialLatestVersion: nil,
+ latestVersion: v.Must(v.NewSemver("1.0.1")),
+ shouldUpdateInit: false,
+ shouldUpdateLater: true,
+ },
+ }
+
+ for idx, c := range testMatrix {
+ mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
+ tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
+ m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
+ m.update = mockUpdate
+
+ targetVersionChan := make(chan string, 1)
+
+ m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
+ targetVersionChan <- targetVersion
+ return nil
+ }
+ m.currentVersion = c.daemonVersion
+ m.Start(context.Background())
+ m.SetVersion("latest")
+ var triggeredInit bool
+ select {
+ case targetVersion := <-targetVersionChan:
+ if targetVersion != c.initialLatestVersion.String() {
+ t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
+ }
+ triggeredInit = true
+ case <-time.After(10 * time.Millisecond):
+ triggeredInit = false
+ }
+ if triggeredInit != c.shouldUpdateInit {
+ t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
+ }
+
+ mockUpdate.latestVersion = c.latestVersion
+ mockUpdate.onUpdate()
+
+ var triggeredLater bool
+ select {
+ case targetVersion := <-targetVersionChan:
+ if targetVersion != c.latestVersion.String() {
+ t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
+ }
+ triggeredLater = true
+ case <-time.After(10 * time.Millisecond):
+ triggeredLater = false
+ }
+ if triggeredLater != c.shouldUpdateLater {
+ t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
+ }
+
+ m.Stop()
+ }
+}
+
+func Test_HandleUpdate(t *testing.T) {
+ testMatrix := []struct {
+ name string
+ daemonVersion string
+ latestVersion *v.Version
+ expectedVersion string
+ shouldUpdate bool
+ }{
+ {
+ name: "Update to a specific version should update regardless of if latestVersion is available yet",
+ daemonVersion: "0.55.0",
+ latestVersion: nil,
+ expectedVersion: "0.56.0",
+ shouldUpdate: true,
+ },
+ {
+ name: "Update to specific version should not update if version matches",
+ daemonVersion: "0.55.0",
+ latestVersion: nil,
+ expectedVersion: "0.55.0",
+ shouldUpdate: false,
+ },
+ {
+ name: "Update to specific version should not update if current version is newer",
+ daemonVersion: "0.55.0",
+ latestVersion: nil,
+ expectedVersion: "0.54.0",
+ shouldUpdate: false,
+ },
+ {
+ name: "Update to latest version should update if latest is newer",
+ daemonVersion: "0.55.0",
+ latestVersion: v.Must(v.NewSemver("0.56.0")),
+ expectedVersion: "latest",
+ shouldUpdate: true,
+ },
+ {
+ name: "Update to latest version should not update if latest == current",
+ daemonVersion: "0.56.0",
+ latestVersion: v.Must(v.NewSemver("0.56.0")),
+ expectedVersion: "latest",
+ shouldUpdate: false,
+ },
+ {
+ name: "Should not update if daemon version is invalid",
+ daemonVersion: "development",
+ latestVersion: v.Must(v.NewSemver("1.0.0")),
+ expectedVersion: "latest",
+ shouldUpdate: false,
+ },
+ {
+ name: "Should not update if expecting latest and latest version is unavailable",
+ daemonVersion: "0.55.0",
+ latestVersion: nil,
+ expectedVersion: "latest",
+ shouldUpdate: false,
+ },
+ {
+ name: "Should not update if expected version is invalid",
+ daemonVersion: "0.55.0",
+ latestVersion: nil,
+ expectedVersion: "development",
+ shouldUpdate: false,
+ },
+ }
+ for idx, c := range testMatrix {
+ tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
+ m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
+ m.update = &versionUpdateMock{latestVersion: c.latestVersion}
+ targetVersionChan := make(chan string, 1)
+
+ m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
+ targetVersionChan <- targetVersion
+ return nil
+ }
+
+ m.currentVersion = c.daemonVersion
+ m.Start(context.Background())
+ m.SetVersion(c.expectedVersion)
+
+ var updateTriggered bool
+ select {
+ case targetVersion := <-targetVersionChan:
+ if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
+ t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
+ } else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
+ t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
+ }
+ updateTriggered = true
+ case <-time.After(10 * time.Millisecond):
+ updateTriggered = false
+ }
+
+ if updateTriggered != c.shouldUpdate {
+ t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
+ }
+ m.Stop()
+ }
+}
diff --git a/client/internal/updatemanager/manager_unsupported.go b/client/internal/updatemanager/manager_unsupported.go
new file mode 100644
index 000000000..4e87c2d77
--- /dev/null
+++ b/client/internal/updatemanager/manager_unsupported.go
@@ -0,0 +1,39 @@
+//go:build !windows && !darwin
+
+package updatemanager
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/internal/statemanager"
+)
+
+// Manager is a no-op stub for unsupported platforms
+type Manager struct{}
+
+// NewManager returns a no-op manager for unsupported platforms
+func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
+ return nil, fmt.Errorf("update manager is not supported on this platform")
+}
+
+// CheckUpdateSuccess is a no-op on unsupported platforms
+func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
+ // no-op
+}
+
+// Start is a no-op on unsupported platforms
+func (m *Manager) Start(ctx context.Context) {
+ // no-op
+}
+
+// SetVersion is a no-op on unsupported platforms
+func (m *Manager) SetVersion(expectedVersion string) {
+ // no-op
+}
+
+// Stop is a no-op on unsupported platforms
+func (m *Manager) Stop() {
+ // no-op
+}
diff --git a/client/internal/updatemanager/reposign/artifact.go b/client/internal/updatemanager/reposign/artifact.go
new file mode 100644
index 000000000..3d4fe9c74
--- /dev/null
+++ b/client/internal/updatemanager/reposign/artifact.go
@@ -0,0 +1,302 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/binary"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "hash"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/crypto/blake2s"
+)
+
+const (
+ tagArtifactPrivate = "ARTIFACT PRIVATE KEY"
+ tagArtifactPublic = "ARTIFACT PUBLIC KEY"
+
+ maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour
+ maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour
+)
+
+// ArtifactHash wraps a hash.Hash and counts bytes written
+type ArtifactHash struct {
+ hash.Hash
+}
+
+// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s
+func NewArtifactHash() *ArtifactHash {
+ h, err := blake2s.New256(nil)
+ if err != nil {
+ panic(err) // Should never happen with nil Key
+ }
+ return &ArtifactHash{Hash: h}
+}
+
+func (ah *ArtifactHash) Write(b []byte) (int, error) {
+ return ah.Hash.Write(b)
+}
+
+// ArtifactKey is a signing Key used to sign artifacts
+type ArtifactKey struct {
+ PrivateKey
+}
+
+func (k ArtifactKey) String() string {
+ return fmt.Sprintf(
+ "ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
+ k.Metadata.ID,
+ k.Metadata.CreatedAt.Format(time.RFC3339),
+ k.Metadata.ExpiresAt.Format(time.RFC3339),
+ )
+}
+
+func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) {
+ // Verify root key is still valid
+ if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) {
+ return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339))
+ }
+
+ now := time.Now()
+ expirationTime := now.Add(expiration)
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err)
+ }
+
+ metadata := KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: now.UTC(),
+ ExpiresAt: expirationTime.UTC(),
+ }
+
+ ak := &ArtifactKey{
+ PrivateKey{
+ Key: priv,
+ Metadata: metadata,
+ },
+ }
+
+ // Marshal PrivateKey struct to JSON
+ privJSON, err := json.Marshal(ak.PrivateKey)
+ if err != nil {
+ return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
+ }
+
+ // Marshal PublicKey struct to JSON
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: metadata,
+ }
+ pubJSON, err := json.Marshal(pubKey)
+ if err != nil {
+ return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
+ }
+
+ // Encode to PEM with metadata embedded in bytes
+ privPEM := pem.EncodeToMemory(&pem.Block{
+ Type: tagArtifactPrivate,
+ Bytes: privJSON,
+ })
+
+ pubPEM := pem.EncodeToMemory(&pem.Block{
+ Type: tagArtifactPublic,
+ Bytes: pubJSON,
+ })
+
+ // Sign the public key with the root key
+ signature, err := SignArtifactKey(*rootKey, pubPEM)
+ if err != nil {
+ return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err)
+ }
+
+ return ak, privPEM, pubPEM, signature, nil
+}
+
+func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) {
+ pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate)
+ if err != nil {
+ return ArtifactKey{}, fmt.Errorf("failed to parse artifact Key: %w", err)
+ }
+ return ArtifactKey{pk}, nil
+}
+
+func ParseArtifactPubKey(data []byte) (PublicKey, error) {
+ pk, _, err := parsePublicKey(data, tagArtifactPublic)
+ return pk, err
+}
+
+func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) {
+ if len(keys) == 0 {
+ return nil, nil, errors.New("no keys to bundle")
+ }
+
+ // Create bundle by concatenating PEM-encoded keys
+ var pubBundle []byte
+
+ for _, pk := range keys {
+ // Marshal PublicKey struct to JSON
+ pubJSON, err := json.Marshal(pk)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
+ }
+
+ // Encode to PEM
+ pubPEM := pem.EncodeToMemory(&pem.Block{
+ Type: tagArtifactPublic,
+ Bytes: pubJSON,
+ })
+
+ pubBundle = append(pubBundle, pubPEM...)
+ }
+
+ // Sign the entire bundle with the root key
+ signature, err := SignArtifactKey(*rootKey, pubBundle)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err)
+ }
+
+ return pubBundle, signature, nil
+}
+
+func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) {
+ now := time.Now().UTC()
+ if signature.Timestamp.After(now.Add(maxClockSkew)) {
+ err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp)
+ log.Debugf("artifact signature error: %v", err)
+ return nil, err
+ }
+ if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge {
+ err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp)
+ log.Debugf("artifact signature error: %v", err)
+ return nil, err
+ }
+
+ // Reconstruct the signed message: artifact_key_data || timestamp
+ msg := make([]byte, 0, len(data)+8)
+ msg = append(msg, data...)
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
+
+ if !verifyAny(publicRootKeys, msg, signature.Signature) {
+ return nil, errors.New("failed to verify signature of artifact keys")
+ }
+
+ pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic)
+ if err != nil {
+ log.Debugf("failed to parse public keys: %s", err)
+ return nil, err
+ }
+
+ validKeys := make([]PublicKey, 0, len(pubKeys))
+ for _, pubKey := range pubKeys {
+ // Filter out expired keys
+ if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) {
+ log.Debugf("Key %s is expired at %v (current time %v)",
+ pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now)
+ continue
+ }
+
+ if revocationList != nil {
+ if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked {
+ log.Debugf("Key %s is revoked as of %v (created %v)",
+ pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt)
+ continue
+ }
+ }
+ validKeys = append(validKeys, pubKey)
+ }
+
+ if len(validKeys) == 0 {
+ log.Debugf("no valid public keys found for artifact keys")
+ return nil, fmt.Errorf("all %d artifact keys are revoked", len(pubKeys))
+ }
+
+ return validKeys, nil
+}
+
+func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error {
+ // Validate signature timestamp
+ now := time.Now().UTC()
+ if signature.Timestamp.After(now.Add(maxClockSkew)) {
+ err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp)
+ log.Debugf("failed to verify signature of artifact: %s", err)
+ return err
+ }
+ if now.Sub(signature.Timestamp) > maxArtifactSignatureAge {
+ return fmt.Errorf("artifact signature is too old: %v (created %v)",
+ now.Sub(signature.Timestamp), signature.Timestamp)
+ }
+
+ h := NewArtifactHash()
+ if _, err := h.Write(data); err != nil {
+ return fmt.Errorf("failed to hash artifact: %w", err)
+ }
+ hash := h.Sum(nil)
+
+ // Reconstruct the signed message: hash || length || timestamp
+ msg := make([]byte, 0, len(hash)+8+8)
+ msg = append(msg, hash...)
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
+
+ // Find matching Key and verify
+ for _, keyInfo := range artifactPubKeys {
+ if keyInfo.Metadata.ID == signature.KeyID {
+ // Check Key expiration
+ if !keyInfo.Metadata.ExpiresAt.IsZero() &&
+ signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) {
+ return fmt.Errorf("signing Key %s expired at %v, signature from %v",
+ signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp)
+ }
+
+ if ed25519.Verify(keyInfo.Key, msg, signature.Signature) {
+ log.Debugf("artifact verified successfully with Key: %s", signature.KeyID)
+ return nil
+ }
+ return fmt.Errorf("signature verification failed for Key %s", signature.KeyID)
+ }
+ }
+
+ return fmt.Errorf("no signing Key found with ID %s", signature.KeyID)
+}
+
+func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) {
+ if len(data) == 0 { // Check happens too late
+ return nil, fmt.Errorf("artifact length must be positive, got %d", len(data))
+ }
+
+ h := NewArtifactHash()
+ if _, err := h.Write(data); err != nil {
+ return nil, fmt.Errorf("failed to write artifact hash: %w", err)
+ }
+
+ timestamp := time.Now().UTC()
+
+ if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) {
+ return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt)
+ }
+
+ hash := h.Sum(nil)
+
+ // Create message: hash || length || timestamp
+ msg := make([]byte, 0, len(hash)+8+8)
+ msg = append(msg, hash...)
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
+
+ sig := ed25519.Sign(artifactKey.Key, msg)
+
+ bundle := Signature{
+ Signature: sig,
+ Timestamp: timestamp,
+ KeyID: artifactKey.Metadata.ID,
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ return json.Marshal(bundle)
+}
diff --git a/client/internal/updatemanager/reposign/artifact_test.go b/client/internal/updatemanager/reposign/artifact_test.go
new file mode 100644
index 000000000..8865e2d0a
--- /dev/null
+++ b/client/internal/updatemanager/reposign/artifact_test.go
@@ -0,0 +1,1080 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/json"
+ "encoding/pem"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test ArtifactHash
+
+func TestNewArtifactHash(t *testing.T) {
+ h := NewArtifactHash()
+ assert.NotNil(t, h)
+ assert.NotNil(t, h.Hash)
+}
+
+func TestArtifactHash_Write(t *testing.T) {
+ h := NewArtifactHash()
+
+ data := []byte("test data")
+ n, err := h.Write(data)
+ require.NoError(t, err)
+ assert.Equal(t, len(data), n)
+
+ hash := h.Sum(nil)
+ assert.NotEmpty(t, hash)
+ assert.Equal(t, 32, len(hash)) // BLAKE2s-256
+}
+
+func TestArtifactHash_Deterministic(t *testing.T) {
+ data := []byte("test data")
+
+ h1 := NewArtifactHash()
+ if _, err := h1.Write(data); err != nil {
+ t.Fatal(err)
+ }
+ hash1 := h1.Sum(nil)
+
+ h2 := NewArtifactHash()
+ if _, err := h2.Write(data); err != nil {
+ t.Fatal(err)
+ }
+ hash2 := h2.Sum(nil)
+
+ assert.Equal(t, hash1, hash2)
+}
+
+func TestArtifactHash_DifferentData(t *testing.T) {
+ h1 := NewArtifactHash()
+ if _, err := h1.Write([]byte("data1")); err != nil {
+ t.Fatal(err)
+ }
+ hash1 := h1.Sum(nil)
+
+ h2 := NewArtifactHash()
+ if _, err := h2.Write([]byte("data2")); err != nil {
+ t.Fatal(err)
+ }
+ hash2 := h2.Sum(nil)
+
+ assert.NotEqual(t, hash1, hash2)
+}
+
+// Test ArtifactKey.String()
+
+func TestArtifactKey_String(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ expiresAt := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
+
+ ak := ArtifactKey{
+ PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: createdAt,
+ ExpiresAt: expiresAt,
+ },
+ },
+ }
+
+ str := ak.String()
+ assert.Contains(t, str, "ArtifactKey")
+ assert.Contains(t, str, computeKeyID(pub).String())
+ assert.Contains(t, str, "2024-01-15")
+ assert.Contains(t, str, "2025-01-15")
+}
+
+// Test GenerateArtifactKey
+
+func TestGenerateArtifactKey_Valid(t *testing.T) {
+ // Create root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
+ },
+ },
+ }
+
+ // Generate artifact key
+ ak, privPEM, pubPEM, signature, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ assert.NotNil(t, ak)
+ assert.NotEmpty(t, privPEM)
+ assert.NotEmpty(t, pubPEM)
+ assert.NotEmpty(t, signature)
+
+ // Verify expiration
+ assert.True(t, ak.Metadata.ExpiresAt.After(time.Now()))
+ assert.True(t, ak.Metadata.ExpiresAt.Before(time.Now().Add(31*24*time.Hour)))
+}
+
+func TestGenerateArtifactKey_ExpiredRoot(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ // Create expired root key
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().Add(-2 * 365 * 24 * time.Hour).UTC(),
+ ExpiresAt: time.Now().Add(-1 * time.Hour).UTC(), // Expired
+ },
+ },
+ }
+
+ _, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "expired")
+}
+
+func TestGenerateArtifactKey_NoExpiration(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ // Root key with no expiration
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Time{}, // No expiration
+ },
+ },
+ }
+
+ ak, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ assert.NotNil(t, ak)
+}
+
+// Test ParseArtifactKey
+
+func TestParseArtifactKey_Valid(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ original, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ // Parse it back
+ parsed, err := ParseArtifactKey(privPEM)
+ require.NoError(t, err)
+
+ assert.Equal(t, original.Key, parsed.Key)
+ assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID)
+}
+
+func TestParseArtifactKey_InvalidPEM(t *testing.T) {
+ _, err := ParseArtifactKey([]byte("invalid pem"))
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to parse")
+}
+
+func TestParseArtifactKey_WrongType(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ // Create a root key (wrong type)
+ rootKey := RootKey{
+ PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ privJSON, err := json.Marshal(rootKey.PrivateKey)
+ require.NoError(t, err)
+
+ privPEM := encodePrivateKey(privJSON, tagRootPrivate)
+
+ _, err = ParseArtifactKey(privPEM)
+ assert.Error(t, err)
+}
+
+// Test ParseArtifactPubKey
+
+func TestParseArtifactPubKey_Valid(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ original, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ parsed, err := ParseArtifactPubKey(pubPEM)
+ require.NoError(t, err)
+
+ assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID)
+}
+
+func TestParseArtifactPubKey_Invalid(t *testing.T) {
+ _, err := ParseArtifactPubKey([]byte("invalid"))
+ assert.Error(t, err)
+}
+
+// Test BundleArtifactKeys
+
+func TestBundleArtifactKeys_Single(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ pubKey, err := ParseArtifactPubKey(pubPEM)
+ require.NoError(t, err)
+
+ bundle, signature, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey})
+ require.NoError(t, err)
+ assert.NotEmpty(t, bundle)
+ assert.NotEmpty(t, signature)
+}
+
+func TestBundleArtifactKeys_Multiple(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate 3 artifact keys
+ var pubKeys []PublicKey
+ for i := 0; i < 3; i++ {
+ _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ pubKey, err := ParseArtifactPubKey(pubPEM)
+ require.NoError(t, err)
+ pubKeys = append(pubKeys, pubKey)
+ }
+
+ bundle, signature, err := BundleArtifactKeys(rootKey, pubKeys)
+ require.NoError(t, err)
+ assert.NotEmpty(t, bundle)
+ assert.NotEmpty(t, signature)
+
+ // Verify we can parse the bundle
+ parsed, err := parsePublicKeyBundle(bundle, tagArtifactPublic)
+ require.NoError(t, err)
+ assert.Len(t, parsed, 3)
+}
+
+func TestBundleArtifactKeys_Empty(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ _, _, err = BundleArtifactKeys(rootKey, []PublicKey{})
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "no keys")
+}
+
+// Test ValidateArtifactKeys
+
+func TestSingleValidateArtifactKey_Valid(t *testing.T) {
+ // Create root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate artifact key
+ _, _, pubPEM, sigData, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ sig, _ := ParseSignature(sigData)
+
+ // Validate
+ validKeys, err := ValidateArtifactKeys(rootKeys, pubPEM, *sig, nil)
+ require.NoError(t, err)
+ assert.Len(t, validKeys, 1)
+}
+
+func TestValidateArtifactKeys_Valid(t *testing.T) {
+ // Create root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate artifact key
+ _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ pubKey, err := ParseArtifactPubKey(pubPEM)
+ require.NoError(t, err)
+
+ // Bundle and sign
+ bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey})
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Validate
+ validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil)
+ require.NoError(t, err)
+ assert.Len(t, validKeys, 1)
+}
+
+func TestValidateArtifactKeys_FutureTimestamp(t *testing.T) {
+ rootPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ sig := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC().Add(10 * time.Minute),
+ KeyID: computeKeyID(rootPub),
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ _, err = ValidateArtifactKeys(rootKeys, []byte("data"), sig, nil)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "in the future")
+}
+
+func TestValidateArtifactKeys_TooOld(t *testing.T) {
+ rootPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ sig := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC().Add(-20 * 365 * 24 * time.Hour),
+ KeyID: computeKeyID(rootPub),
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ _, err = ValidateArtifactKeys(rootKeys, []byte("data"), sig, nil)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "too old")
+}
+
+func TestValidateArtifactKeys_InvalidSignature(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ pubKey, err := ParseArtifactPubKey(pubPEM)
+ require.NoError(t, err)
+
+ bundle, _, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey})
+ require.NoError(t, err)
+
+ // Create invalid signature
+ invalidSig := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC(),
+ KeyID: computeKeyID(rootPub),
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ _, err = ValidateArtifactKeys(rootKeys, bundle, invalidSig, nil)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to verify")
+}
+
+func TestValidateArtifactKeys_WithRevocation(t *testing.T) {
+ // Create root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate two artifact keys
+ _, _, pubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ pubKey1, err := ParseArtifactPubKey(pubPEM1)
+ require.NoError(t, err)
+
+ _, _, pubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ pubKey2, err := ParseArtifactPubKey(pubPEM2)
+ require.NoError(t, err)
+
+ // Bundle both keys
+ bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey1, pubKey2})
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Create revocation list with first key revoked
+ revocationList := &RevocationList{
+ Revoked: map[KeyID]time.Time{
+ pubKey1.Metadata.ID: time.Now().UTC(),
+ },
+ LastUpdated: time.Now().UTC(),
+ }
+
+ // Validate - should only return second key
+ validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, revocationList)
+ require.NoError(t, err)
+ assert.Len(t, validKeys, 1)
+ assert.Equal(t, pubKey2.Metadata.ID, validKeys[0].Metadata.ID)
+}
+
+func TestValidateArtifactKeys_AllRevoked(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ pubKey, err := ParseArtifactPubKey(pubPEM)
+ require.NoError(t, err)
+
+ bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey})
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Revoke the key
+ revocationList := &RevocationList{
+ Revoked: map[KeyID]time.Time{
+ pubKey.Metadata.ID: time.Now().UTC(),
+ },
+ LastUpdated: time.Now().UTC(),
+ }
+
+ _, err = ValidateArtifactKeys(rootKeys, bundle, *sig, revocationList)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "revoked")
+}
+
+// Test ValidateArtifact
+
+func TestValidateArtifact_Valid(t *testing.T) {
+ // Create root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate artifact key
+ artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ // Sign some data
+ data := []byte("test artifact data")
+ sigData, err := SignData(*artifactKey, data)
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Get public key for validation
+ artifactPubKey := PublicKey{
+ Key: artifactKey.Key.Public().(ed25519.PublicKey),
+ Metadata: artifactKey.Metadata,
+ }
+
+ // Validate
+ err = ValidateArtifact([]PublicKey{artifactPubKey}, data, *sig)
+ assert.NoError(t, err)
+}
+
+func TestValidateArtifact_FutureTimestamp(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ artifactPubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ sig := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC().Add(10 * time.Minute),
+ KeyID: computeKeyID(pub),
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ err = ValidateArtifact([]PublicKey{artifactPubKey}, []byte("data"), sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "in the future")
+}
+
+func TestValidateArtifact_TooOld(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ artifactPubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ sig := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC().Add(-20 * 365 * 24 * time.Hour),
+ KeyID: computeKeyID(pub),
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ err = ValidateArtifact([]PublicKey{artifactPubKey}, []byte("data"), sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "too old")
+}
+
+func TestValidateArtifact_ExpiredKey(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate artifact key with very short expiration
+ artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond)
+ require.NoError(t, err)
+
+ // Wait for key to expire
+ time.Sleep(10 * time.Millisecond)
+
+ // Try to sign - should succeed but with old timestamp
+ data := []byte("test data")
+ sigData, err := SignData(*artifactKey, data)
+ require.Error(t, err) // Key is expired, so signing should fail
+ assert.Contains(t, err.Error(), "expired")
+ assert.Nil(t, sigData)
+}
+
+func TestValidateArtifact_WrongKey(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate two artifact keys
+ artifactKey1, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ artifactKey2, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ // Sign with key1
+ data := []byte("test data")
+ sigData, err := SignData(*artifactKey1, data)
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Try to validate with key2 only
+ artifactPubKey2 := PublicKey{
+ Key: artifactKey2.Key.Public().(ed25519.PublicKey),
+ Metadata: artifactKey2.Metadata,
+ }
+
+ err = ValidateArtifact([]PublicKey{artifactPubKey2}, data, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "no signing Key found")
+}
+
+func TestValidateArtifact_TamperedData(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ // Sign original data
+ originalData := []byte("original data")
+ sigData, err := SignData(*artifactKey, originalData)
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ artifactPubKey := PublicKey{
+ Key: artifactKey.Key.Public().(ed25519.PublicKey),
+ Metadata: artifactKey.Metadata,
+ }
+
+ // Try to validate with tampered data
+ tamperedData := []byte("tampered data")
+ err = ValidateArtifact([]PublicKey{artifactPubKey}, tamperedData, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "verification failed")
+}
+
+func TestValidateArtifactKeys_TwoKeysOneExpired(t *testing.T) {
+ // Test ValidateArtifactKeys with a bundle containing two keys where one is expired
+ // Should return only the valid key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate first key with very short expiration
+ _, _, expiredPubPEM, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond)
+ require.NoError(t, err)
+ expiredPubKey, err := ParseArtifactPubKey(expiredPubPEM)
+ require.NoError(t, err)
+
+ // Wait for first key to expire
+ time.Sleep(10 * time.Millisecond)
+
+ // Generate second key with normal expiration
+ _, _, validPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ validPubKey, err := ParseArtifactPubKey(validPubPEM)
+ require.NoError(t, err)
+
+ // Bundle both keys together
+ bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{expiredPubKey, validPubKey})
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // ValidateArtifactKeys should return only the valid key
+ validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil)
+ require.NoError(t, err)
+ assert.Len(t, validKeys, 1)
+ assert.Equal(t, validPubKey.Metadata.ID, validKeys[0].Metadata.ID)
+}
+
+func TestValidateArtifactKeys_TwoKeysBothExpired(t *testing.T) {
+ // Test ValidateArtifactKeys with a bundle containing two expired keys
+ // Should fail because no valid keys remain
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate first key with
+ _, _, pubPEM1, _, err := GenerateArtifactKey(rootKey, 24*time.Hour)
+ require.NoError(t, err)
+ pubKey1, err := ParseArtifactPubKey(pubPEM1)
+ require.NoError(t, err)
+
+ // Generate second key with very short expiration
+ _, _, pubPEM2, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond)
+ require.NoError(t, err)
+ pubKey2, err := ParseArtifactPubKey(pubPEM2)
+ require.NoError(t, err)
+
+ // Wait for expire
+ time.Sleep(10 * time.Millisecond)
+
+ bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey1, pubKey2})
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // ValidateArtifactKeys should fail because all keys are expired
+ keys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil)
+ assert.NoError(t, err)
+ assert.Len(t, keys, 1)
+}
+
+// Test SignData
+
+func TestSignData_Valid(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ data := []byte("test data to sign")
+ sigData, err := SignData(*artifactKey, data)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sigData)
+
+ // Verify signature can be parsed
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+ assert.Equal(t, "ed25519", sig.Algorithm)
+ assert.Equal(t, "blake2s", sig.HashAlgo)
+}
+
+func TestSignData_EmptyData(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ _, err = SignData(*artifactKey, []byte{})
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "must be positive")
+}
+
+func TestSignData_ExpiredKey(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate key with very short expiration
+ artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond)
+ require.NoError(t, err)
+
+ // Wait for expiration
+ time.Sleep(10 * time.Millisecond)
+
+ // Try to sign with expired key
+ _, err = SignData(*artifactKey, []byte("data"))
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "expired")
+}
+
+// Integration test
+
+func TestArtifact_FullWorkflow(t *testing.T) {
+ // Step 1: Create root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Step 2: Generate artifact key
+ artifactKey, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ // Step 3: Create and validate key bundle
+ artifactPubKey, err := ParseArtifactPubKey(pubPEM)
+ require.NoError(t, err)
+
+ bundle, bundleSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(bundleSig)
+ require.NoError(t, err)
+
+ validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil)
+ require.NoError(t, err)
+ assert.Len(t, validKeys, 1)
+
+ // Step 4: Sign artifact data
+ artifactData := []byte("This is my artifact data that needs to be signed")
+ artifactSig, err := SignData(*artifactKey, artifactData)
+ require.NoError(t, err)
+
+ // Step 5: Validate artifact
+ parsedSig, err := ParseSignature(artifactSig)
+ require.NoError(t, err)
+
+ err = ValidateArtifact(validKeys, artifactData, *parsedSig)
+ assert.NoError(t, err)
+}
+
+// Helper function for tests
+func encodePrivateKey(jsonData []byte, typeTag string) []byte {
+ return pem.EncodeToMemory(&pem.Block{
+ Type: typeTag,
+ Bytes: jsonData,
+ })
+}
diff --git a/client/internal/updatemanager/reposign/certs/root-pub.pem b/client/internal/updatemanager/reposign/certs/root-pub.pem
new file mode 100644
index 000000000..e7c2fd2c0
--- /dev/null
+++ b/client/internal/updatemanager/reposign/certs/root-pub.pem
@@ -0,0 +1,6 @@
+-----BEGIN ROOT PUBLIC KEY-----
+eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn
+V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0
+ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz
+X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19
+-----END ROOT PUBLIC KEY-----
diff --git a/client/internal/updatemanager/reposign/certsdev/root-pub.pem b/client/internal/updatemanager/reposign/certsdev/root-pub.pem
new file mode 100644
index 000000000..f7145477b
--- /dev/null
+++ b/client/internal/updatemanager/reposign/certsdev/root-pub.pem
@@ -0,0 +1,6 @@
+-----BEGIN ROOT PUBLIC KEY-----
+eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr
+ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0
+ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz
+X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19
+-----END ROOT PUBLIC KEY-----
diff --git a/client/internal/updatemanager/reposign/doc.go b/client/internal/updatemanager/reposign/doc.go
new file mode 100644
index 000000000..660b9d11d
--- /dev/null
+++ b/client/internal/updatemanager/reposign/doc.go
@@ -0,0 +1,174 @@
+// Package reposign implements a cryptographic signing and verification system
+// for NetBird software update artifacts. It provides a hierarchical key
+// management system with support for key rotation, revocation, and secure
+// artifact distribution.
+//
+// # Architecture
+//
+// The package uses a two-tier key hierarchy:
+//
+// - Root Keys: Long-lived keys that sign artifact keys. These are embedded
+// in the client binary and establish the root of trust. Root keys should
+// be kept offline and highly secured.
+//
+// - Artifact Keys: Short-lived keys that sign release artifacts (binaries,
+// packages, etc.). These are rotated regularly and can be revoked if
+// compromised. Artifact keys are signed by root keys and distributed via
+// a public repository.
+//
+// This separation allows for operational flexibility: artifact keys can be
+// rotated frequently without requiring client updates, while root keys remain
+// stable and embedded in the software.
+//
+// # Cryptographic Primitives
+//
+// The package uses strong, modern cryptographic algorithms:
+// - Ed25519: Fast, secure digital signatures (no timing attacks)
+// - BLAKE2s-256: Fast cryptographic hash for artifacts
+// - SHA-256: Key ID generation
+// - JSON: Structured key and signature serialization
+// - PEM: Standard key encoding format
+//
+// # Security Features
+//
+// Timestamp Binding:
+// - All signatures include cryptographically-bound timestamps
+// - Prevents replay attacks and enforces signature freshness
+// - Clock skew tolerance: 5 minutes
+//
+// Key Expiration:
+// - All keys have expiration times
+// - Expired keys are automatically rejected
+// - Signing with an expired key fails immediately
+//
+// Key Revocation:
+// - Compromised keys can be revoked via a signed revocation list
+// - Revocation list is checked during artifact validation
+// - Revoked keys are filtered out before artifact verification
+//
+// # File Structure
+//
+// The package expects the following file layout in the key repository:
+//
+// signrepo/
+// artifact-key-pub.pem # Bundle of artifact public keys
+// artifact-key-pub.pem.sig # Root signature of the bundle
+// revocation-list.json # List of revoked key IDs
+// revocation-list.json.sig # Root signature of revocation list
+//
+// And in the artifacts repository:
+//
+// releases/
+// v0.28.0/
+// netbird-linux-amd64
+// netbird-linux-amd64.sig # Artifact signature
+// netbird-darwin-amd64
+// netbird-darwin-amd64.sig
+// ...
+//
+// # Embedded Root Keys
+//
+// Root public keys are embedded in the client binary at compile time:
+// - Production keys: certs/ directory
+// - Development keys: certsdev/ directory
+//
+// The build tag determines which keys are embedded:
+// - Production builds: //go:build !devartifactsign
+// - Development builds: //go:build devartifactsign
+//
+// This ensures that development artifacts cannot be verified using production
+// keys and vice versa.
+//
+// # Key Rotation Strategies
+//
+// Root Key Rotation:
+//
+// Root keys can be rotated without breaking existing clients by leveraging
+// the multi-key verification system. The loadEmbeddedPublicKeys function
+// reads ALL files from the certs/ directory and accepts signatures from ANY
+// of the embedded root keys.
+//
+// To rotate root keys:
+//
+// 1. Generate a new root key pair:
+// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+//
+// 2. Add the new public key to the certs/ directory as a new file:
+// certs/
+// root-pub-2024.pem # Old key (keep this!)
+// root-pub-2025.pem # New key (add this)
+//
+// 3. Build new client versions with both keys embedded. The verification
+// will accept signatures from either key.
+//
+// 4. Start signing new artifact keys with the new root key. Old clients
+// with only the old root key will reject these, but new clients with
+// both keys will accept them.
+//
+// Each file in certs/ can contain a single key or a bundle of keys (multiple
+// PEM blocks). The system will parse all keys from all files and use them
+// for verification. This provides maximum flexibility for key management.
+//
+// Important: Never remove all old root keys at once. Always maintain at least
+// one overlapping key between releases to ensure smooth transitions.
+//
+// Artifact Key Rotation:
+//
+// Artifact keys should be rotated regularly (e.g., every 90 days) using the
+// bundling mechanism. The BundleArtifactKeys function allows multiple artifact
+// keys to be bundled together in a single signed package, and ValidateArtifact
+// will accept signatures from ANY key in the bundle.
+//
+// To rotate artifact keys smoothly:
+//
+// 1. Generate a new artifact key while keeping the old one:
+// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour)
+// // Keep oldPubPEM and oldKey available
+//
+// 2. Create a bundle containing both old and new public keys
+//
+// 3. Upload the bundle and its signature to the key repository:
+// signrepo/artifact-key-pub.pem # Contains both keys
+// signrepo/artifact-key-pub.pem.sig # Root signature
+//
+// 4. Start signing new releases with the NEW key, but keep the bundle
+// unchanged. Clients will download the bundle (containing both keys)
+// and accept signatures from either key.
+//
+// Key bundle validation workflow:
+// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig
+// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key
+// 3. ValidateArtifactKeys parses all public keys from the bundle
+// 4. ValidateArtifactKeys filters out expired or revoked keys
+// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds
+//
+// This multi-key acceptance model enables overlapping validity periods and
+// smooth transitions without client update requirements.
+//
+// # Best Practices
+//
+// Root Key Management:
+// - Generate root keys offline on an air-gapped machine
+// - Store root private keys in hardware security modules (HSM) if possible
+// - Use separate root keys for production and development
+// - Rotate root keys infrequently (e.g., every 5-10 years)
+// - Plan for root key rotation: embed multiple root public keys
+//
+// Artifact Key Management:
+// - Rotate artifact keys regularly (e.g., every 90 days)
+// - Use separate artifact keys for different release channels if needed
+// - Revoke keys immediately upon suspected compromise
+// - Bundle multiple artifact keys to enable smooth rotation
+//
+// Signing Process:
+// - Sign artifacts in a secure CI/CD environment
+// - Never commit private keys to version control
+// - Use environment variables or secret management for keys
+// - Verify signatures immediately after signing
+//
+// Distribution:
+// - Serve keys and revocation lists from a reliable CDN
+// - Use HTTPS for all key and artifact downloads
+// - Monitor download failures and signature verification failures
+// - Keep revocation list up to date
+package reposign
diff --git a/client/internal/updatemanager/reposign/embed_dev.go b/client/internal/updatemanager/reposign/embed_dev.go
new file mode 100644
index 000000000..ef8f77373
--- /dev/null
+++ b/client/internal/updatemanager/reposign/embed_dev.go
@@ -0,0 +1,10 @@
+//go:build devartifactsign
+
+package reposign
+
+import "embed"
+
+//go:embed certsdev
+var embeddedCerts embed.FS
+
+const embeddedCertsDir = "certsdev"
diff --git a/client/internal/updatemanager/reposign/embed_prod.go b/client/internal/updatemanager/reposign/embed_prod.go
new file mode 100644
index 000000000..91530e5f4
--- /dev/null
+++ b/client/internal/updatemanager/reposign/embed_prod.go
@@ -0,0 +1,10 @@
+//go:build !devartifactsign
+
+package reposign
+
+import "embed"
+
+//go:embed certs
+var embeddedCerts embed.FS
+
+const embeddedCertsDir = "certs"
diff --git a/client/internal/updatemanager/reposign/key.go b/client/internal/updatemanager/reposign/key.go
new file mode 100644
index 000000000..bedfef70d
--- /dev/null
+++ b/client/internal/updatemanager/reposign/key.go
@@ -0,0 +1,171 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "time"
+)
+
+const (
+ maxClockSkew = 5 * time.Minute
+)
+
+// KeyID is a unique identifier for a Key (first 8 bytes of SHA-256 of public Key)
+type KeyID [8]byte
+
+// computeKeyID generates a unique ID from a public Key
+func computeKeyID(pub ed25519.PublicKey) KeyID {
+ h := sha256.Sum256(pub)
+ var id KeyID
+ copy(id[:], h[:8])
+ return id
+}
+
+// MarshalJSON implements json.Marshaler for KeyID
+func (k KeyID) MarshalJSON() ([]byte, error) {
+ return json.Marshal(k.String())
+}
+
+// UnmarshalJSON implements json.Unmarshaler for KeyID
+func (k *KeyID) UnmarshalJSON(data []byte) error {
+ var s string
+ if err := json.Unmarshal(data, &s); err != nil {
+ return err
+ }
+
+ parsed, err := ParseKeyID(s)
+ if err != nil {
+ return err
+ }
+
+ *k = parsed
+ return nil
+}
+
+// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID.
+func ParseKeyID(s string) (KeyID, error) {
+ var id KeyID
+ if len(s) != 16 {
+ return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s))
+ }
+
+ b, err := hex.DecodeString(s)
+ if err != nil {
+ return id, fmt.Errorf("failed to decode KeyID: %w", err)
+ }
+
+ copy(id[:], b)
+ return id, nil
+}
+
+func (k KeyID) String() string {
+ return fmt.Sprintf("%x", k[:])
+}
+
+// KeyMetadata contains versioning and lifecycle information for a Key
+type KeyMetadata struct {
+ ID KeyID `json:"id"`
+ CreatedAt time.Time `json:"created_at"`
+ ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration
+}
+
+// PublicKey wraps a public Key with its Metadata
+type PublicKey struct {
+ Key ed25519.PublicKey
+ Metadata KeyMetadata
+}
+
+func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) {
+ var keys []PublicKey
+ for len(bundle) > 0 {
+ keyInfo, rest, err := parsePublicKey(bundle, typeTag)
+ if err != nil {
+ return nil, err
+ }
+ keys = append(keys, keyInfo)
+ bundle = rest
+ }
+ if len(keys) == 0 {
+ return nil, errors.New("no keys found in bundle")
+ }
+ return keys, nil
+}
+
+func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) {
+ b, rest := pem.Decode(data)
+ if b == nil {
+ return PublicKey{}, nil, errors.New("failed to decode PEM data")
+ }
+ if b.Type != typeTag {
+ return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
+ }
+
+ // Unmarshal JSON-embedded format
+ var pub PublicKey
+ if err := json.Unmarshal(b.Bytes, &pub); err != nil {
+ return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err)
+ }
+
+ // Validate key length
+ if len(pub.Key) != ed25519.PublicKeySize {
+ return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d",
+ ed25519.PublicKeySize, len(pub.Key))
+ }
+
+ // Always recompute ID to ensure integrity
+ pub.Metadata.ID = computeKeyID(pub.Key)
+
+ return pub, rest, nil
+}
+
+type PrivateKey struct {
+ Key ed25519.PrivateKey
+ Metadata KeyMetadata
+}
+
+func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) {
+ b, rest := pem.Decode(data)
+ if b == nil {
+ return PrivateKey{}, errors.New("failed to decode PEM data")
+ }
+ if len(rest) > 0 {
+ return PrivateKey{}, errors.New("trailing PEM data")
+ }
+ if b.Type != typeTag {
+ return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
+ }
+
+ // Unmarshal JSON-embedded format
+ var pk PrivateKey
+ if err := json.Unmarshal(b.Bytes, &pk); err != nil {
+ return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err)
+ }
+
+ // Validate key length
+ if len(pk.Key) != ed25519.PrivateKeySize {
+ return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d",
+ ed25519.PrivateKeySize, len(pk.Key))
+ }
+
+ return pk, nil
+}
+
+func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool {
+ // Verify with root keys
+ var rootKeys []ed25519.PublicKey
+ for _, r := range publicRootKeys {
+ rootKeys = append(rootKeys, r.Key)
+ }
+
+ for _, k := range rootKeys {
+ if ed25519.Verify(k, msg, sig) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/client/internal/updatemanager/reposign/key_test.go b/client/internal/updatemanager/reposign/key_test.go
new file mode 100644
index 000000000..f8e1676fb
--- /dev/null
+++ b/client/internal/updatemanager/reposign/key_test.go
@@ -0,0 +1,636 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/json"
+ "encoding/pem"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test KeyID functions
+
+func TestComputeKeyID(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID := computeKeyID(pub)
+
+ // Verify it's the first 8 bytes of SHA-256
+ h := sha256.Sum256(pub)
+ expectedID := KeyID{}
+ copy(expectedID[:], h[:8])
+
+ assert.Equal(t, expectedID, keyID)
+}
+
+func TestComputeKeyID_Deterministic(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ // Computing KeyID multiple times should give the same result
+ keyID1 := computeKeyID(pub)
+ keyID2 := computeKeyID(pub)
+
+ assert.Equal(t, keyID1, keyID2)
+}
+
+func TestComputeKeyID_DifferentKeys(t *testing.T) {
+ pub1, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pub2, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID1 := computeKeyID(pub1)
+ keyID2 := computeKeyID(pub2)
+
+ // Different keys should produce different IDs
+ assert.NotEqual(t, keyID1, keyID2)
+}
+
+func TestParseKeyID_Valid(t *testing.T) {
+ hexStr := "0123456789abcdef"
+
+ keyID, err := ParseKeyID(hexStr)
+ require.NoError(t, err)
+
+ expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
+ assert.Equal(t, expected, keyID)
+}
+
+func TestParseKeyID_InvalidLength(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"too short", "01234567"},
+ {"too long", "0123456789abcdef00"},
+ {"empty", ""},
+ {"odd length", "0123456789abcde"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := ParseKeyID(tt.input)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid KeyID length")
+ })
+ }
+}
+
+func TestParseKeyID_InvalidHex(t *testing.T) {
+ invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex
+
+ _, err := ParseKeyID(invalidHex)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to decode KeyID")
+}
+
+func TestKeyID_String(t *testing.T) {
+ keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
+
+ str := keyID.String()
+ assert.Equal(t, "0123456789abcdef", str)
+}
+
+func TestKeyID_RoundTrip(t *testing.T) {
+ original := "fedcba9876543210"
+
+ keyID, err := ParseKeyID(original)
+ require.NoError(t, err)
+
+ result := keyID.String()
+ assert.Equal(t, original, result)
+}
+
+func TestKeyID_ZeroValue(t *testing.T) {
+ keyID := KeyID{}
+ str := keyID.String()
+ assert.Equal(t, "0000000000000000", str)
+}
+
+// Test KeyMetadata
+
+func TestKeyMetadata_JSONMarshaling(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ metadata := KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
+ ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
+ }
+
+ jsonData, err := json.Marshal(metadata)
+ require.NoError(t, err)
+
+ var decoded KeyMetadata
+ err = json.Unmarshal(jsonData, &decoded)
+ require.NoError(t, err)
+
+ assert.Equal(t, metadata.ID, decoded.ID)
+ assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix())
+ assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix())
+}
+
+func TestKeyMetadata_NoExpiration(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ metadata := KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
+ ExpiresAt: time.Time{}, // Zero value = no expiration
+ }
+
+ jsonData, err := json.Marshal(metadata)
+ require.NoError(t, err)
+
+ var decoded KeyMetadata
+ err = json.Unmarshal(jsonData, &decoded)
+ require.NoError(t, err)
+
+ assert.True(t, decoded.ExpiresAt.IsZero())
+}
+
+// Test PublicKey
+
+func TestPublicKey_JSONMarshaling(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ var decoded PublicKey
+ err = json.Unmarshal(jsonData, &decoded)
+ require.NoError(t, err)
+
+ assert.Equal(t, pubKey.Key, decoded.Key)
+ assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID)
+}
+
+// Test parsePublicKey
+
+func TestParsePublicKey_Valid(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ metadata := KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
+ }
+
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: metadata,
+ }
+
+ // Marshal to JSON
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ // Encode to PEM
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: jsonData,
+ })
+
+ // Parse it back
+ parsed, rest, err := parsePublicKey(pemData, tagRootPublic)
+ require.NoError(t, err)
+ assert.Empty(t, rest)
+ assert.Equal(t, pub, parsed.Key)
+ assert.Equal(t, metadata.ID, parsed.Metadata.ID)
+}
+
+func TestParsePublicKey_InvalidPEM(t *testing.T) {
+ invalidPEM := []byte("not a PEM")
+
+ _, _, err := parsePublicKey(invalidPEM, tagRootPublic)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to decode PEM")
+}
+
+func TestParsePublicKey_WrongType(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ // Encode with wrong type
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: "WRONG TYPE",
+ Bytes: jsonData,
+ })
+
+ _, _, err = parsePublicKey(pemData, tagRootPublic)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "PEM type")
+}
+
+func TestParsePublicKey_InvalidJSON(t *testing.T) {
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: []byte("invalid json"),
+ })
+
+ _, _, err := parsePublicKey(pemData, tagRootPublic)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to unmarshal")
+}
+
+func TestParsePublicKey_InvalidKeySize(t *testing.T) {
+ // Create a public key with wrong size
+ pubKey := PublicKey{
+ Key: []byte{0x01, 0x02, 0x03}, // Too short
+ Metadata: KeyMetadata{
+ ID: KeyID{},
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: jsonData,
+ })
+
+ _, _, err = parsePublicKey(pemData, tagRootPublic)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "incorrect Ed25519 public key size")
+}
+
+func TestParsePublicKey_IDRecomputation(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ // Create a public key with WRONG ID
+ wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: wrongID,
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: jsonData,
+ })
+
+ // Parse should recompute the correct ID
+ parsed, _, err := parsePublicKey(pemData, tagRootPublic)
+ require.NoError(t, err)
+
+ correctID := computeKeyID(pub)
+ assert.Equal(t, correctID, parsed.Metadata.ID)
+ assert.NotEqual(t, wrongID, parsed.Metadata.ID)
+}
+
+// Test parsePublicKeyBundle
+
+func TestParsePublicKeyBundle_Single(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: jsonData,
+ })
+
+ keys, err := parsePublicKeyBundle(pemData, tagRootPublic)
+ require.NoError(t, err)
+ assert.Len(t, keys, 1)
+ assert.Equal(t, pub, keys[0].Key)
+}
+
+func TestParsePublicKeyBundle_Multiple(t *testing.T) {
+ var bundle []byte
+
+ // Create 3 keys
+ for i := 0; i < 3; i++ {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: jsonData,
+ })
+
+ bundle = append(bundle, pemData...)
+ }
+
+ keys, err := parsePublicKeyBundle(bundle, tagRootPublic)
+ require.NoError(t, err)
+ assert.Len(t, keys, 3)
+}
+
+func TestParsePublicKeyBundle_Empty(t *testing.T) {
+ _, err := parsePublicKeyBundle([]byte{}, tagRootPublic)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "no keys found")
+}
+
+func TestParsePublicKeyBundle_Invalid(t *testing.T) {
+ _, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic)
+ assert.Error(t, err)
+}
+
+// Test PrivateKey
+
+func TestPrivateKey_JSONMarshaling(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ privKey := PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(privKey)
+ require.NoError(t, err)
+
+ var decoded PrivateKey
+ err = json.Unmarshal(jsonData, &decoded)
+ require.NoError(t, err)
+
+ assert.Equal(t, privKey.Key, decoded.Key)
+ assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID)
+}
+
+// Test parsePrivateKey
+
+func TestParsePrivateKey_Valid(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ privKey := PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(privKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: jsonData,
+ })
+
+ parsed, err := parsePrivateKey(pemData, tagRootPrivate)
+ require.NoError(t, err)
+ assert.Equal(t, priv, parsed.Key)
+}
+
+func TestParsePrivateKey_InvalidPEM(t *testing.T) {
+ _, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to decode PEM")
+}
+
+func TestParsePrivateKey_TrailingData(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ privKey := PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(privKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: jsonData,
+ })
+
+ // Add trailing data
+ pemData = append(pemData, []byte("extra data")...)
+
+ _, err = parsePrivateKey(pemData, tagRootPrivate)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "trailing PEM data")
+}
+
+func TestParsePrivateKey_WrongType(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ privKey := PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(privKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: "WRONG TYPE",
+ Bytes: jsonData,
+ })
+
+ _, err = parsePrivateKey(pemData, tagRootPrivate)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "PEM type")
+}
+
+func TestParsePrivateKey_InvalidKeySize(t *testing.T) {
+ privKey := PrivateKey{
+ Key: []byte{0x01, 0x02, 0x03}, // Too short
+ Metadata: KeyMetadata{
+ ID: KeyID{},
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(privKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: jsonData,
+ })
+
+ _, err = parsePrivateKey(pemData, tagRootPrivate)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
+}
+
+// Test verifyAny
+
+func TestVerifyAny_ValidSignature(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ message := []byte("test message")
+ signature := ed25519.Sign(priv, message)
+
+ rootKeys := []PublicKey{
+ {
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ result := verifyAny(rootKeys, message, signature)
+ assert.True(t, result)
+}
+
+func TestVerifyAny_InvalidSignature(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ message := []byte("test message")
+ invalidSignature := make([]byte, ed25519.SignatureSize)
+
+ rootKeys := []PublicKey{
+ {
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ result := verifyAny(rootKeys, message, invalidSignature)
+ assert.False(t, result)
+}
+
+func TestVerifyAny_MultipleKeys(t *testing.T) {
+ // Create 3 key pairs
+ pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pub2, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pub3, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ message := []byte("test message")
+ signature := ed25519.Sign(priv1, message)
+
+ rootKeys := []PublicKey{
+ {Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
+ {Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle
+ {Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}},
+ }
+
+ result := verifyAny(rootKeys, message, signature)
+ assert.True(t, result)
+}
+
+func TestVerifyAny_NoMatchingKey(t *testing.T) {
+ _, priv1, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pub2, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ message := []byte("test message")
+ signature := ed25519.Sign(priv1, message)
+
+ // Only include pub2, not pub1
+ rootKeys := []PublicKey{
+ {Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
+ }
+
+ result := verifyAny(rootKeys, message, signature)
+ assert.False(t, result)
+}
+
+func TestVerifyAny_EmptyKeys(t *testing.T) {
+ message := []byte("test message")
+ signature := make([]byte, ed25519.SignatureSize)
+
+ result := verifyAny([]PublicKey{}, message, signature)
+ assert.False(t, result)
+}
+
+func TestVerifyAny_TamperedMessage(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ message := []byte("test message")
+ signature := ed25519.Sign(priv, message)
+
+ rootKeys := []PublicKey{
+ {Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}},
+ }
+
+ // Verify with different message
+ tamperedMessage := []byte("different message")
+ result := verifyAny(rootKeys, tamperedMessage, signature)
+ assert.False(t, result)
+}
diff --git a/client/internal/updatemanager/reposign/revocation.go b/client/internal/updatemanager/reposign/revocation.go
new file mode 100644
index 000000000..e679e212f
--- /dev/null
+++ b/client/internal/updatemanager/reposign/revocation.go
@@ -0,0 +1,229 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ maxRevocationSignatureAge = 10 * 365 * 24 * time.Hour
+ defaultRevocationListExpiration = 365 * 24 * time.Hour
+)
+
+type RevocationList struct {
+ Revoked map[KeyID]time.Time `json:"revoked"` // KeyID -> revocation time
+ LastUpdated time.Time `json:"last_updated"` // When the list was last modified
+ ExpiresAt time.Time `json:"expires_at"` // When the list expires
+}
+
+func (rl RevocationList) MarshalJSON() ([]byte, error) {
+ // Convert map[KeyID]time.Time to map[string]time.Time
+ strMap := make(map[string]time.Time, len(rl.Revoked))
+ for k, v := range rl.Revoked {
+ strMap[k.String()] = v
+ }
+
+ return json.Marshal(map[string]interface{}{
+ "revoked": strMap,
+ "last_updated": rl.LastUpdated,
+ "expires_at": rl.ExpiresAt,
+ })
+}
+
+func (rl *RevocationList) UnmarshalJSON(data []byte) error {
+ var temp struct {
+ Revoked map[string]time.Time `json:"revoked"`
+ LastUpdated time.Time `json:"last_updated"`
+ ExpiresAt time.Time `json:"expires_at"`
+ Version int `json:"version"`
+ }
+
+ if err := json.Unmarshal(data, &temp); err != nil {
+ return err
+ }
+
+ // Convert map[string]time.Time back to map[KeyID]time.Time
+ rl.Revoked = make(map[KeyID]time.Time, len(temp.Revoked))
+ for k, v := range temp.Revoked {
+ kid, err := ParseKeyID(k)
+ if err != nil {
+ return fmt.Errorf("failed to parse KeyID %q: %w", k, err)
+ }
+ rl.Revoked[kid] = v
+ }
+
+ rl.LastUpdated = temp.LastUpdated
+ rl.ExpiresAt = temp.ExpiresAt
+
+ return nil
+}
+
+func ParseRevocationList(data []byte) (*RevocationList, error) {
+ var rl RevocationList
+ if err := json.Unmarshal(data, &rl); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal revocation list: %w", err)
+ }
+
+ // Initialize the map if it's nil (in case of empty JSON object)
+ if rl.Revoked == nil {
+ rl.Revoked = make(map[KeyID]time.Time)
+ }
+
+ if rl.LastUpdated.IsZero() {
+ return nil, fmt.Errorf("revocation list missing last_updated timestamp")
+ }
+
+ if rl.ExpiresAt.IsZero() {
+ return nil, fmt.Errorf("revocation list missing expires_at timestamp")
+ }
+
+ return &rl, nil
+}
+
+func ValidateRevocationList(publicRootKeys []PublicKey, data []byte, signature Signature) (*RevocationList, error) {
+ revoList, err := ParseRevocationList(data)
+ if err != nil {
+ log.Debugf("failed to parse revocation list: %s", err)
+ return nil, err
+ }
+
+ now := time.Now().UTC()
+
+ // Validate signature timestamp
+ if signature.Timestamp.After(now.Add(maxClockSkew)) {
+ err := fmt.Errorf("revocation signature timestamp is in the future: %v", signature.Timestamp)
+ log.Debugf("revocation list signature error: %v", err)
+ return nil, err
+ }
+
+ if now.Sub(signature.Timestamp) > maxRevocationSignatureAge {
+ err := fmt.Errorf("revocation list signature is too old: %v (created %v)",
+ now.Sub(signature.Timestamp), signature.Timestamp)
+ log.Debugf("revocation list signature error: %v", err)
+ return nil, err
+ }
+
+ // Ensure LastUpdated is not in the future (with clock skew tolerance)
+ if revoList.LastUpdated.After(now.Add(maxClockSkew)) {
+ err := fmt.Errorf("revocation list LastUpdated is in the future: %v", revoList.LastUpdated)
+ log.Errorf("rejecting future-dated revocation list: %v", err)
+ return nil, err
+ }
+
+ // Check if the revocation list has expired
+ if now.After(revoList.ExpiresAt) {
+ err := fmt.Errorf("revocation list expired at %v (current time: %v)", revoList.ExpiresAt, now)
+ log.Errorf("rejecting expired revocation list: %v", err)
+ return nil, err
+ }
+
+ // Ensure ExpiresAt is not in the future by more than the expected expiration window
+ // (allows some clock skew but prevents maliciously long expiration times)
+ if revoList.ExpiresAt.After(now.Add(maxRevocationSignatureAge)) {
+ err := fmt.Errorf("revocation list ExpiresAt is too far in the future: %v", revoList.ExpiresAt)
+ log.Errorf("rejecting revocation list with invalid expiration: %v", err)
+ return nil, err
+ }
+
+ // Validate signature timestamp is close to LastUpdated
+ // (prevents signing old lists with new timestamps)
+ timeDiff := signature.Timestamp.Sub(revoList.LastUpdated).Abs()
+ if timeDiff > maxClockSkew {
+ err := fmt.Errorf("signature timestamp %v differs too much from list LastUpdated %v (diff: %v)",
+ signature.Timestamp, revoList.LastUpdated, timeDiff)
+ log.Errorf("timestamp mismatch in revocation list: %v", err)
+ return nil, err
+ }
+
+ // Reconstruct the signed message: revocation_list_data || timestamp || version
+ msg := make([]byte, 0, len(data)+8)
+ msg = append(msg, data...)
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
+
+ if !verifyAny(publicRootKeys, msg, signature.Signature) {
+ return nil, errors.New("revocation list verification failed")
+ }
+ return revoList, nil
+}
+
+func CreateRevocationList(privateRootKey RootKey, expiration time.Duration) ([]byte, []byte, error) {
+ now := time.Now()
+ rl := RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: now.UTC(),
+ ExpiresAt: now.Add(expiration).UTC(),
+ }
+
+ signature, err := signRevocationList(privateRootKey, rl)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
+ }
+
+ rlData, err := json.Marshal(&rl)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
+ }
+
+ signData, err := json.Marshal(signature)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
+ }
+
+ return rlData, signData, nil
+}
+
+func ExtendRevocationList(privateRootKey RootKey, rl RevocationList, kid KeyID, expiration time.Duration) ([]byte, []byte, error) {
+ now := time.Now().UTC()
+
+ rl.Revoked[kid] = now
+ rl.LastUpdated = now
+ rl.ExpiresAt = now.Add(expiration)
+
+ signature, err := signRevocationList(privateRootKey, rl)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
+ }
+
+ rlData, err := json.Marshal(&rl)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
+ }
+
+ signData, err := json.Marshal(signature)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
+ }
+
+ return rlData, signData, nil
+}
+
+func signRevocationList(privateRootKey RootKey, rl RevocationList) (*Signature, error) {
+ data, err := json.Marshal(rl)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal revocation list for signing: %w", err)
+ }
+
+ timestamp := time.Now().UTC()
+
+ msg := make([]byte, 0, len(data)+8)
+ msg = append(msg, data...)
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
+
+ sig := ed25519.Sign(privateRootKey.Key, msg)
+
+ signature := &Signature{
+ Signature: sig,
+ Timestamp: timestamp,
+ KeyID: privateRootKey.Metadata.ID,
+ Algorithm: "ed25519",
+ HashAlgo: "sha512",
+ }
+
+ return signature, nil
+}
diff --git a/client/internal/updatemanager/reposign/revocation_test.go b/client/internal/updatemanager/reposign/revocation_test.go
new file mode 100644
index 000000000..d6d748f3d
--- /dev/null
+++ b/client/internal/updatemanager/reposign/revocation_test.go
@@ -0,0 +1,860 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test RevocationList marshaling/unmarshaling
+
+func TestRevocationList_MarshalJSON(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID := computeKeyID(pub)
+ revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
+ expiresAt := time.Date(2024, 4, 15, 11, 0, 0, 0, time.UTC)
+
+ rl := &RevocationList{
+ Revoked: map[KeyID]time.Time{
+ keyID: revokedTime,
+ },
+ LastUpdated: lastUpdated,
+ ExpiresAt: expiresAt,
+ }
+
+ jsonData, err := json.Marshal(rl)
+ require.NoError(t, err)
+
+ // Verify it can be unmarshaled back
+ var decoded map[string]interface{}
+ err = json.Unmarshal(jsonData, &decoded)
+ require.NoError(t, err)
+
+ assert.Contains(t, decoded, "revoked")
+ assert.Contains(t, decoded, "last_updated")
+ assert.Contains(t, decoded, "expires_at")
+}
+
+func TestRevocationList_UnmarshalJSON(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID := computeKeyID(pub)
+ revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
+
+ jsonData := map[string]interface{}{
+ "revoked": map[string]string{
+ keyID.String(): revokedTime.Format(time.RFC3339),
+ },
+ "last_updated": lastUpdated.Format(time.RFC3339),
+ }
+
+ jsonBytes, err := json.Marshal(jsonData)
+ require.NoError(t, err)
+
+ var rl RevocationList
+ err = json.Unmarshal(jsonBytes, &rl)
+ require.NoError(t, err)
+
+ assert.Len(t, rl.Revoked, 1)
+ assert.Contains(t, rl.Revoked, keyID)
+ assert.Equal(t, lastUpdated.Unix(), rl.LastUpdated.Unix())
+}
+
+func TestRevocationList_MarshalUnmarshal_Roundtrip(t *testing.T) {
+ pub1, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ pub2, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID1 := computeKeyID(pub1)
+ keyID2 := computeKeyID(pub2)
+
+ original := &RevocationList{
+ Revoked: map[KeyID]time.Time{
+ keyID1: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
+ keyID2: time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC),
+ },
+ LastUpdated: time.Date(2024, 2, 20, 15, 0, 0, 0, time.UTC),
+ }
+
+ // Marshal
+ jsonData, err := original.MarshalJSON()
+ require.NoError(t, err)
+
+ // Unmarshal
+ var decoded RevocationList
+ err = decoded.UnmarshalJSON(jsonData)
+ require.NoError(t, err)
+
+ // Verify
+ assert.Len(t, decoded.Revoked, 2)
+ assert.Equal(t, original.Revoked[keyID1].Unix(), decoded.Revoked[keyID1].Unix())
+ assert.Equal(t, original.Revoked[keyID2].Unix(), decoded.Revoked[keyID2].Unix())
+ assert.Equal(t, original.LastUpdated.Unix(), decoded.LastUpdated.Unix())
+}
+
+func TestRevocationList_UnmarshalJSON_InvalidKeyID(t *testing.T) {
+ jsonData := []byte(`{
+ "revoked": {
+ "invalid_key_id": "2024-01-15T10:30:00Z"
+ },
+ "last_updated": "2024-01-15T11:00:00Z"
+ }`)
+
+ var rl RevocationList
+ err := json.Unmarshal(jsonData, &rl)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to parse KeyID")
+}
+
+func TestRevocationList_EmptyRevoked(t *testing.T) {
+ rl := &RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: time.Now().UTC(),
+ }
+
+ jsonData, err := rl.MarshalJSON()
+ require.NoError(t, err)
+
+ var decoded RevocationList
+ err = decoded.UnmarshalJSON(jsonData)
+ require.NoError(t, err)
+
+ assert.Empty(t, decoded.Revoked)
+ assert.NotNil(t, decoded.Revoked)
+}
+
+// Test ParseRevocationList
+
+func TestParseRevocationList_Valid(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID := computeKeyID(pub)
+ revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
+
+ rl := RevocationList{
+ Revoked: map[KeyID]time.Time{
+ keyID: revokedTime,
+ },
+ LastUpdated: lastUpdated,
+ ExpiresAt: time.Date(2025, 2, 20, 14, 45, 0, 0, time.UTC),
+ }
+
+ jsonData, err := rl.MarshalJSON()
+ require.NoError(t, err)
+
+ parsed, err := ParseRevocationList(jsonData)
+ require.NoError(t, err)
+ assert.NotNil(t, parsed)
+ assert.Len(t, parsed.Revoked, 1)
+ assert.Equal(t, lastUpdated.Unix(), parsed.LastUpdated.Unix())
+}
+
+func TestParseRevocationList_InvalidJSON(t *testing.T) {
+ invalidJSON := []byte("not valid json")
+
+ _, err := ParseRevocationList(invalidJSON)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to unmarshal")
+}
+
+func TestParseRevocationList_MissingLastUpdated(t *testing.T) {
+ jsonData := []byte(`{
+ "revoked": {}
+ }`)
+
+ _, err := ParseRevocationList(jsonData)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "missing last_updated")
+}
+
+func TestParseRevocationList_EmptyObject(t *testing.T) {
+ jsonData := []byte(`{}`)
+
+ _, err := ParseRevocationList(jsonData)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "missing last_updated")
+}
+
+func TestParseRevocationList_NilRevoked(t *testing.T) {
+ lastUpdated := time.Now().UTC()
+ expiresAt := lastUpdated.Add(90 * 24 * time.Hour)
+ jsonData := []byte(`{
+ "last_updated": "` + lastUpdated.Format(time.RFC3339) + `",
+ "expires_at": "` + expiresAt.Format(time.RFC3339) + `"
+ }`)
+
+ parsed, err := ParseRevocationList(jsonData)
+ require.NoError(t, err)
+ assert.NotNil(t, parsed.Revoked)
+ assert.Empty(t, parsed.Revoked)
+}
+
+func TestParseRevocationList_MissingExpiresAt(t *testing.T) {
+ lastUpdated := time.Now().UTC()
+ jsonData := []byte(`{
+ "revoked": {},
+ "last_updated": "` + lastUpdated.Format(time.RFC3339) + `"
+ }`)
+
+ _, err := ParseRevocationList(jsonData)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "missing expires_at")
+}
+
+// Test ValidateRevocationList
+
+func TestValidateRevocationList_Valid(t *testing.T) {
+ // Generate root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list
+ rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ signature, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Validate
+ rl, err := ValidateRevocationList(rootKeys, rlData, *signature)
+ require.NoError(t, err)
+ assert.NotNil(t, rl)
+ assert.Empty(t, rl.Revoked)
+}
+
+func TestValidateRevocationList_InvalidSignature(t *testing.T) {
+ // Generate root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list
+ rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Create invalid signature
+ invalidSig := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC(),
+ KeyID: computeKeyID(rootPub),
+ Algorithm: "ed25519",
+ HashAlgo: "sha512",
+ }
+
+ // Validate should fail
+ _, err = ValidateRevocationList(rootKeys, rlData, invalidSig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "verification failed")
+}
+
+func TestValidateRevocationList_FutureTimestamp(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ signature, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Modify timestamp to be in the future
+ signature.Timestamp = time.Now().UTC().Add(10 * time.Minute)
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *signature)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "in the future")
+}
+
+func TestValidateRevocationList_TooOld(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ signature, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Modify timestamp to be too old
+ signature.Timestamp = time.Now().UTC().Add(-20 * 365 * 24 * time.Hour)
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *signature)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "too old")
+}
+
+func TestValidateRevocationList_InvalidJSON(t *testing.T) {
+ rootPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ signature := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC(),
+ KeyID: computeKeyID(rootPub),
+ Algorithm: "ed25519",
+ HashAlgo: "sha512",
+ }
+
+ _, err = ValidateRevocationList(rootKeys, []byte("invalid json"), signature)
+ assert.Error(t, err)
+}
+
+func TestValidateRevocationList_FutureLastUpdated(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list with future LastUpdated
+ rl := RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: time.Now().UTC().Add(10 * time.Minute),
+ ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
+ }
+
+ rlData, err := json.Marshal(rl)
+ require.NoError(t, err)
+
+ // Sign it
+ sig, err := signRevocationList(rootKey, rl)
+ require.NoError(t, err)
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "LastUpdated is in the future")
+}
+
+func TestValidateRevocationList_TimestampMismatch(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list with LastUpdated far in the past
+ rl := RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: time.Now().UTC().Add(-1 * time.Hour),
+ ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
+ }
+
+ rlData, err := json.Marshal(rl)
+ require.NoError(t, err)
+
+ // Sign it with current timestamp
+ sig, err := signRevocationList(rootKey, rl)
+ require.NoError(t, err)
+
+ // Modify signature timestamp to differ too much from LastUpdated
+ sig.Timestamp = time.Now().UTC()
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "differs too much")
+}
+
+func TestValidateRevocationList_Expired(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list that expired in the past
+ now := time.Now().UTC()
+ rl := RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: now.Add(-100 * 24 * time.Hour),
+ ExpiresAt: now.Add(-10 * 24 * time.Hour), // Expired 10 days ago
+ }
+
+ rlData, err := json.Marshal(rl)
+ require.NoError(t, err)
+
+ // Sign it
+ sig, err := signRevocationList(rootKey, rl)
+ require.NoError(t, err)
+ // Adjust signature timestamp to match LastUpdated
+ sig.Timestamp = rl.LastUpdated
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "expired")
+}
+
+func TestValidateRevocationList_ExpiresAtTooFarInFuture(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list with ExpiresAt too far in the future (beyond maxRevocationSignatureAge)
+ now := time.Now().UTC()
+ rl := RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: now,
+ ExpiresAt: now.Add(15 * 365 * 24 * time.Hour), // 15 years in the future
+ }
+
+ rlData, err := json.Marshal(rl)
+ require.NoError(t, err)
+
+ // Sign it
+ sig, err := signRevocationList(rootKey, rl)
+ require.NoError(t, err)
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "too far in the future")
+}
+
+// Test CreateRevocationList
+
+func TestCreateRevocationList_Valid(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+ assert.NotEmpty(t, rlData)
+ assert.NotEmpty(t, sigData)
+
+ // Verify it can be parsed
+ rl, err := ParseRevocationList(rlData)
+ require.NoError(t, err)
+ assert.Empty(t, rl.Revoked)
+ assert.False(t, rl.LastUpdated.IsZero())
+
+ // Verify signature can be parsed
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+}
+
+// Test ExtendRevocationList
+
+func TestExtendRevocationList_AddKey(t *testing.T) {
+ // Generate root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create empty revocation list
+ rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err := ParseRevocationList(rlData)
+ require.NoError(t, err)
+ assert.Empty(t, rl.Revoked)
+
+ // Generate a key to revoke
+ revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ revokedKeyID := computeKeyID(revokedPub)
+
+ // Extend the revocation list
+ newRLData, newSigData, err := ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Verify the new list
+ newRL, err := ParseRevocationList(newRLData)
+ require.NoError(t, err)
+ assert.Len(t, newRL.Revoked, 1)
+ assert.Contains(t, newRL.Revoked, revokedKeyID)
+
+ // Verify signature
+ sig, err := ParseSignature(newSigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+}
+
+func TestExtendRevocationList_MultipleKeys(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create empty revocation list
+ rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err := ParseRevocationList(rlData)
+ require.NoError(t, err)
+
+ // Add first key
+ key1Pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ key1ID := computeKeyID(key1Pub)
+
+ rlData, _, err = ExtendRevocationList(rootKey, *rl, key1ID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err = ParseRevocationList(rlData)
+ require.NoError(t, err)
+ assert.Len(t, rl.Revoked, 1)
+
+ // Add second key
+ key2Pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ key2ID := computeKeyID(key2Pub)
+
+ rlData, _, err = ExtendRevocationList(rootKey, *rl, key2ID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err = ParseRevocationList(rlData)
+ require.NoError(t, err)
+ assert.Len(t, rl.Revoked, 2)
+ assert.Contains(t, rl.Revoked, key1ID)
+ assert.Contains(t, rl.Revoked, key2ID)
+}
+
+func TestExtendRevocationList_DuplicateKey(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create empty revocation list
+ rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err := ParseRevocationList(rlData)
+ require.NoError(t, err)
+
+ // Add a key
+ keyPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ keyID := computeKeyID(keyPub)
+
+ rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err = ParseRevocationList(rlData)
+ require.NoError(t, err)
+ firstRevocationTime := rl.Revoked[keyID]
+
+ // Wait a bit
+ time.Sleep(10 * time.Millisecond)
+
+ // Add the same key again
+ rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err = ParseRevocationList(rlData)
+ require.NoError(t, err)
+ assert.Len(t, rl.Revoked, 1)
+
+ // The revocation time should be updated
+ secondRevocationTime := rl.Revoked[keyID]
+ assert.True(t, secondRevocationTime.After(firstRevocationTime) || secondRevocationTime.Equal(firstRevocationTime))
+}
+
+func TestExtendRevocationList_UpdatesLastUpdated(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list
+ rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err := ParseRevocationList(rlData)
+ require.NoError(t, err)
+ firstLastUpdated := rl.LastUpdated
+
+ // Wait a bit
+ time.Sleep(10 * time.Millisecond)
+
+ // Extend list
+ keyPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ keyID := computeKeyID(keyPub)
+
+ rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err = ParseRevocationList(rlData)
+ require.NoError(t, err)
+
+ // LastUpdated should be updated
+ assert.True(t, rl.LastUpdated.After(firstLastUpdated))
+}
+
+// Integration test
+
+func TestRevocationList_FullWorkflow(t *testing.T) {
+ // Create root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Step 1: Create empty revocation list
+ rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Step 2: Validate it
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ rl, err := ValidateRevocationList(rootKeys, rlData, *sig)
+ require.NoError(t, err)
+ assert.Empty(t, rl.Revoked)
+
+ // Step 3: Revoke a key
+ revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ revokedKeyID := computeKeyID(revokedPub)
+
+ rlData, sigData, err = ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Step 4: Validate the extended list
+ sig, err = ParseSignature(sigData)
+ require.NoError(t, err)
+
+ rl, err = ValidateRevocationList(rootKeys, rlData, *sig)
+ require.NoError(t, err)
+ assert.Len(t, rl.Revoked, 1)
+ assert.Contains(t, rl.Revoked, revokedKeyID)
+
+ // Step 5: Verify the revocation time is reasonable
+ revTime := rl.Revoked[revokedKeyID]
+ now := time.Now().UTC()
+ assert.True(t, revTime.Before(now) || revTime.Equal(now))
+ assert.True(t, now.Sub(revTime) < time.Minute)
+}
diff --git a/client/internal/updatemanager/reposign/root.go b/client/internal/updatemanager/reposign/root.go
new file mode 100644
index 000000000..2c3ca54a0
--- /dev/null
+++ b/client/internal/updatemanager/reposign/root.go
@@ -0,0 +1,120 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/binary"
+ "encoding/json"
+ "encoding/pem"
+ "fmt"
+ "time"
+)
+
+const (
+ tagRootPrivate = "ROOT PRIVATE KEY"
+ tagRootPublic = "ROOT PUBLIC KEY"
+)
+
+// RootKey is a root Key used to sign signing keys
+type RootKey struct {
+ PrivateKey
+}
+
+func (k RootKey) String() string {
+ return fmt.Sprintf(
+ "RootKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
+ k.Metadata.ID,
+ k.Metadata.CreatedAt.Format(time.RFC3339),
+ k.Metadata.ExpiresAt.Format(time.RFC3339),
+ )
+}
+
+func ParseRootKey(privKeyPEM []byte) (*RootKey, error) {
+ pk, err := parsePrivateKey(privKeyPEM, tagRootPrivate)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse root Key: %w", err)
+ }
+ return &RootKey{pk}, nil
+}
+
+// ParseRootPublicKey parses a root public key from PEM format
+func ParseRootPublicKey(pubKeyPEM []byte) (PublicKey, error) {
+ pk, _, err := parsePublicKey(pubKeyPEM, tagRootPublic)
+ if err != nil {
+ return PublicKey{}, fmt.Errorf("failed to parse root public key: %w", err)
+ }
+ return pk, nil
+}
+
+// GenerateRootKey generates a new root Key pair with Metadata
+func GenerateRootKey(expiration time.Duration) (*RootKey, []byte, []byte, error) {
+ now := time.Now()
+ expirationTime := now.Add(expiration)
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ metadata := KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: now.UTC(),
+ ExpiresAt: expirationTime.UTC(),
+ }
+
+ rk := &RootKey{
+ PrivateKey{
+ Key: priv,
+ Metadata: metadata,
+ },
+ }
+
+ // Marshal PrivateKey struct to JSON
+ privJSON, err := json.Marshal(rk.PrivateKey)
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
+ }
+
+ // Marshal PublicKey struct to JSON
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: metadata,
+ }
+ pubJSON, err := json.Marshal(pubKey)
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
+ }
+
+ // Encode to PEM with metadata embedded in bytes
+ privPEM := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: privJSON,
+ })
+
+ pubPEM := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: pubJSON,
+ })
+
+ return rk, privPEM, pubPEM, nil
+}
+
+func SignArtifactKey(rootKey RootKey, data []byte) ([]byte, error) {
+ timestamp := time.Now().UTC()
+
+ // This ensures the timestamp is cryptographically bound to the signature
+ msg := make([]byte, 0, len(data)+8)
+ msg = append(msg, data...)
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
+
+ sig := ed25519.Sign(rootKey.Key, msg)
+ // Create signature bundle with timestamp and Metadata
+ bundle := Signature{
+ Signature: sig,
+ Timestamp: timestamp,
+ KeyID: rootKey.Metadata.ID,
+ Algorithm: "ed25519",
+ HashAlgo: "sha512",
+ }
+
+ return json.Marshal(bundle)
+}
diff --git a/client/internal/updatemanager/reposign/root_test.go b/client/internal/updatemanager/reposign/root_test.go
new file mode 100644
index 000000000..e75e29729
--- /dev/null
+++ b/client/internal/updatemanager/reposign/root_test.go
@@ -0,0 +1,476 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/binary"
+ "encoding/json"
+ "encoding/pem"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test RootKey.String()
+
+func TestRootKey_String(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ expiresAt := time.Date(2034, 1, 15, 10, 30, 0, 0, time.UTC)
+
+ rk := RootKey{
+ PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: createdAt,
+ ExpiresAt: expiresAt,
+ },
+ },
+ }
+
+ str := rk.String()
+ assert.Contains(t, str, "RootKey")
+ assert.Contains(t, str, computeKeyID(pub).String())
+ assert.Contains(t, str, "2024-01-15")
+ assert.Contains(t, str, "2034-01-15")
+}
+
+func TestRootKey_String_NoExpiration(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+
+ rk := RootKey{
+ PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: createdAt,
+ ExpiresAt: time.Time{}, // No expiration
+ },
+ },
+ }
+
+ str := rk.String()
+ assert.Contains(t, str, "RootKey")
+ assert.Contains(t, str, "0001-01-01") // Zero time format
+}
+
+// Test GenerateRootKey
+
+func TestGenerateRootKey_Valid(t *testing.T) {
+ expiration := 10 * 365 * 24 * time.Hour // 10 years
+
+ rk, privPEM, pubPEM, err := GenerateRootKey(expiration)
+ require.NoError(t, err)
+ assert.NotNil(t, rk)
+ assert.NotEmpty(t, privPEM)
+ assert.NotEmpty(t, pubPEM)
+
+ // Verify the key has correct metadata
+ assert.False(t, rk.Metadata.CreatedAt.IsZero())
+ assert.False(t, rk.Metadata.ExpiresAt.IsZero())
+ assert.True(t, rk.Metadata.ExpiresAt.After(rk.Metadata.CreatedAt))
+
+ // Verify expiration is approximately correct
+ expectedExpiration := time.Now().Add(expiration)
+ timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
+ assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
+}
+
+func TestGenerateRootKey_ShortExpiration(t *testing.T) {
+ expiration := 24 * time.Hour // 1 day
+
+ rk, _, _, err := GenerateRootKey(expiration)
+ require.NoError(t, err)
+ assert.NotNil(t, rk)
+
+ // Verify expiration
+ expectedExpiration := time.Now().Add(expiration)
+ timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
+ assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
+}
+
+func TestGenerateRootKey_ZeroExpiration(t *testing.T) {
+ rk, _, _, err := GenerateRootKey(0)
+ require.NoError(t, err)
+ assert.NotNil(t, rk)
+
+ // With zero expiration, ExpiresAt should be equal to CreatedAt
+ assert.Equal(t, rk.Metadata.CreatedAt, rk.Metadata.ExpiresAt)
+}
+
+func TestGenerateRootKey_PEMFormat(t *testing.T) {
+ rk, privPEM, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Verify private key PEM
+ privBlock, _ := pem.Decode(privPEM)
+ require.NotNil(t, privBlock)
+ assert.Equal(t, tagRootPrivate, privBlock.Type)
+
+ var privKey PrivateKey
+ err = json.Unmarshal(privBlock.Bytes, &privKey)
+ require.NoError(t, err)
+ assert.Equal(t, rk.Key, privKey.Key)
+
+ // Verify public key PEM
+ pubBlock, _ := pem.Decode(pubPEM)
+ require.NotNil(t, pubBlock)
+ assert.Equal(t, tagRootPublic, pubBlock.Type)
+
+ var pubKey PublicKey
+ err = json.Unmarshal(pubBlock.Bytes, &pubKey)
+ require.NoError(t, err)
+ assert.Equal(t, rk.Metadata.ID, pubKey.Metadata.ID)
+}
+
+func TestGenerateRootKey_KeySize(t *testing.T) {
+ rk, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Ed25519 private key should be 64 bytes
+ assert.Equal(t, ed25519.PrivateKeySize, len(rk.Key))
+
+ // Ed25519 public key should be 32 bytes
+ pubKey := rk.Key.Public().(ed25519.PublicKey)
+ assert.Equal(t, ed25519.PublicKeySize, len(pubKey))
+}
+
+func TestGenerateRootKey_UniqueKeys(t *testing.T) {
+ rk1, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ rk2, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Different keys should have different IDs
+ assert.NotEqual(t, rk1.Metadata.ID, rk2.Metadata.ID)
+ assert.NotEqual(t, rk1.Key, rk2.Key)
+}
+
+// Test ParseRootKey
+
+func TestParseRootKey_Valid(t *testing.T) {
+ original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ parsed, err := ParseRootKey(privPEM)
+ require.NoError(t, err)
+ assert.NotNil(t, parsed)
+
+ // Verify the parsed key matches the original
+ assert.Equal(t, original.Key, parsed.Key)
+ assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID)
+ assert.Equal(t, original.Metadata.CreatedAt.Unix(), parsed.Metadata.CreatedAt.Unix())
+ assert.Equal(t, original.Metadata.ExpiresAt.Unix(), parsed.Metadata.ExpiresAt.Unix())
+}
+
+func TestParseRootKey_InvalidPEM(t *testing.T) {
+ _, err := ParseRootKey([]byte("not a valid PEM"))
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to parse")
+}
+
+func TestParseRootKey_EmptyData(t *testing.T) {
+ _, err := ParseRootKey([]byte{})
+ assert.Error(t, err)
+}
+
+func TestParseRootKey_WrongType(t *testing.T) {
+ // Generate an artifact key instead of root key
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ artifactKey, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ // Try to parse artifact key as root key
+ _, err = ParseRootKey(privPEM)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "PEM type")
+
+ // Just to use artifactKey to avoid unused variable warning
+ _ = artifactKey
+}
+
+func TestParseRootKey_CorruptedJSON(t *testing.T) {
+ // Create PEM with corrupted JSON
+ corruptedPEM := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: []byte("corrupted json data"),
+ })
+
+ _, err := ParseRootKey(corruptedPEM)
+ assert.Error(t, err)
+}
+
+func TestParseRootKey_InvalidKeySize(t *testing.T) {
+ // Create a key with invalid size
+ invalidKey := PrivateKey{
+ Key: []byte{0x01, 0x02, 0x03}, // Too short
+ Metadata: KeyMetadata{
+ ID: KeyID{},
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ privJSON, err := json.Marshal(invalidKey)
+ require.NoError(t, err)
+
+ invalidPEM := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: privJSON,
+ })
+
+ _, err = ParseRootKey(invalidPEM)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
+}
+
+func TestParseRootKey_Roundtrip(t *testing.T) {
+ // Generate a key
+ original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Parse it
+ parsed, err := ParseRootKey(privPEM)
+ require.NoError(t, err)
+
+ // Generate PEM again from parsed key
+ privJSON2, err := json.Marshal(parsed.PrivateKey)
+ require.NoError(t, err)
+
+ privPEM2 := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: privJSON2,
+ })
+
+ // Parse again
+ parsed2, err := ParseRootKey(privPEM2)
+ require.NoError(t, err)
+
+ // Should still match original
+ assert.Equal(t, original.Key, parsed2.Key)
+ assert.Equal(t, original.Metadata.ID, parsed2.Metadata.ID)
+}
+
+// Test SignArtifactKey
+
+func TestSignArtifactKey_Valid(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ data := []byte("test data to sign")
+ sigData, err := SignArtifactKey(*rootKey, data)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sigData)
+
+ // Parse and verify signature
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+ assert.Equal(t, rootKey.Metadata.ID, sig.KeyID)
+ assert.Equal(t, "ed25519", sig.Algorithm)
+ assert.Equal(t, "sha512", sig.HashAlgo)
+ assert.False(t, sig.Timestamp.IsZero())
+}
+
+func TestSignArtifactKey_EmptyData(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ sigData, err := SignArtifactKey(*rootKey, []byte{})
+ require.NoError(t, err)
+ assert.NotEmpty(t, sigData)
+
+ // Should still be able to parse
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+}
+
+func TestSignArtifactKey_Verify(t *testing.T) {
+ rootKey, _, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Parse public key
+ pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
+ require.NoError(t, err)
+
+ // Sign some data
+ data := []byte("test data for verification")
+ sigData, err := SignArtifactKey(*rootKey, data)
+ require.NoError(t, err)
+
+ // Parse signature
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Reconstruct message
+ msg := make([]byte, 0, len(data)+8)
+ msg = append(msg, data...)
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))
+
+ // Verify signature
+ valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
+ assert.True(t, valid)
+}
+
+func TestSignArtifactKey_DifferentData(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ data1 := []byte("data1")
+ data2 := []byte("data2")
+
+ sig1, err := SignArtifactKey(*rootKey, data1)
+ require.NoError(t, err)
+
+ sig2, err := SignArtifactKey(*rootKey, data2)
+ require.NoError(t, err)
+
+ // Different data should produce different signatures
+ assert.NotEqual(t, sig1, sig2)
+}
+
+func TestSignArtifactKey_MultipleSignatures(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ data := []byte("test data")
+
+ // Sign twice with a small delay
+ sig1, err := SignArtifactKey(*rootKey, data)
+ require.NoError(t, err)
+
+ time.Sleep(10 * time.Millisecond)
+
+ sig2, err := SignArtifactKey(*rootKey, data)
+ require.NoError(t, err)
+
+ // Signatures should be different due to different timestamps
+ assert.NotEqual(t, sig1, sig2)
+
+ // Parse both signatures
+ parsed1, err := ParseSignature(sig1)
+ require.NoError(t, err)
+
+ parsed2, err := ParseSignature(sig2)
+ require.NoError(t, err)
+
+ // Timestamps should be different
+ assert.True(t, parsed2.Timestamp.After(parsed1.Timestamp))
+}
+
+func TestSignArtifactKey_LargeData(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Create 1MB of data
+ largeData := make([]byte, 1024*1024)
+ for i := range largeData {
+ largeData[i] = byte(i % 256)
+ }
+
+ sigData, err := SignArtifactKey(*rootKey, largeData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sigData)
+
+ // Verify signature can be parsed
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+}
+
+func TestSignArtifactKey_TimestampInSignature(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ beforeSign := time.Now().UTC()
+ data := []byte("test data")
+ sigData, err := SignArtifactKey(*rootKey, data)
+ require.NoError(t, err)
+ afterSign := time.Now().UTC()
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Timestamp should be between before and after
+ assert.True(t, sig.Timestamp.After(beforeSign.Add(-time.Second)))
+ assert.True(t, sig.Timestamp.Before(afterSign.Add(time.Second)))
+}
+
+// Integration test
+
+func TestRootKey_FullWorkflow(t *testing.T) {
+ // Step 1: Generate root key
+ rootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+ require.NoError(t, err)
+ assert.NotNil(t, rootKey)
+ assert.NotEmpty(t, privPEM)
+ assert.NotEmpty(t, pubPEM)
+
+ // Step 2: Parse the private key back
+ parsedRootKey, err := ParseRootKey(privPEM)
+ require.NoError(t, err)
+ assert.Equal(t, rootKey.Key, parsedRootKey.Key)
+ assert.Equal(t, rootKey.Metadata.ID, parsedRootKey.Metadata.ID)
+
+ // Step 3: Generate an artifact key using root key
+ artifactKey, _, artifactPubPEM, artifactSig, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ assert.NotNil(t, artifactKey)
+
+ // Step 4: Verify the artifact key signature
+ pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(artifactSig)
+ require.NoError(t, err)
+
+ artifactPubKey, _, err := parsePublicKey(artifactPubPEM, tagArtifactPublic)
+ require.NoError(t, err)
+
+ // Reconstruct message - SignArtifactKey signs the PEM, not the JSON
+ msg := make([]byte, 0, len(artifactPubPEM)+8)
+ msg = append(msg, artifactPubPEM...)
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))
+
+ // Verify with root public key
+ valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
+ assert.True(t, valid, "Artifact key signature should be valid")
+
+ // Step 5: Use artifact key to sign data
+ testData := []byte("This is test artifact data")
+ dataSig, err := SignData(*artifactKey, testData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, dataSig)
+
+ // Step 6: Verify the artifact data signature
+ dataSigParsed, err := ParseSignature(dataSig)
+ require.NoError(t, err)
+
+ err = ValidateArtifact([]PublicKey{artifactPubKey}, testData, *dataSigParsed)
+ assert.NoError(t, err, "Artifact data signature should be valid")
+}
+
+func TestRootKey_ExpiredKeyWorkflow(t *testing.T) {
+ // Generate a root key that expires very soon
+ rootKey, _, _, err := GenerateRootKey(1 * time.Millisecond)
+ require.NoError(t, err)
+
+ // Wait for expiration
+ time.Sleep(10 * time.Millisecond)
+
+ // Try to generate artifact key with expired root key
+ _, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "expired")
+}
diff --git a/client/internal/updatemanager/reposign/signature.go b/client/internal/updatemanager/reposign/signature.go
new file mode 100644
index 000000000..c7f06e94e
--- /dev/null
+++ b/client/internal/updatemanager/reposign/signature.go
@@ -0,0 +1,24 @@
+package reposign
+
+import (
+ "encoding/json"
+ "time"
+)
+
+// Signature contains a signature with associated Metadata
+type Signature struct {
+ Signature []byte `json:"signature"`
+ Timestamp time.Time `json:"timestamp"`
+ KeyID KeyID `json:"key_id"`
+ Algorithm string `json:"algorithm"` // "ed25519"
+ HashAlgo string `json:"hash_algo"` // "blake2s" or sha512
+}
+
+func ParseSignature(data []byte) (*Signature, error) {
+ var signature Signature
+ if err := json.Unmarshal(data, &signature); err != nil {
+ return nil, err
+ }
+
+ return &signature, nil
+}
diff --git a/client/internal/updatemanager/reposign/signature_test.go b/client/internal/updatemanager/reposign/signature_test.go
new file mode 100644
index 000000000..1960c5518
--- /dev/null
+++ b/client/internal/updatemanager/reposign/signature_test.go
@@ -0,0 +1,277 @@
+package reposign
+
+import (
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseSignature_Valid(t *testing.T) {
+ timestamp := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ keyID, err := ParseKeyID("0123456789abcdef")
+ require.NoError(t, err)
+
+ signatureData := []byte{0x01, 0x02, 0x03, 0x04}
+
+ jsonData, err := json.Marshal(Signature{
+ Signature: signatureData,
+ Timestamp: timestamp,
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ })
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+ assert.NotNil(t, sig)
+ assert.Equal(t, signatureData, sig.Signature)
+ assert.Equal(t, timestamp.Unix(), sig.Timestamp.Unix())
+ assert.Equal(t, keyID, sig.KeyID)
+ assert.Equal(t, "ed25519", sig.Algorithm)
+ assert.Equal(t, "blake2s", sig.HashAlgo)
+}
+
+func TestParseSignature_InvalidJSON(t *testing.T) {
+ invalidJSON := []byte(`{invalid json}`)
+
+ sig, err := ParseSignature(invalidJSON)
+ assert.Error(t, err)
+ assert.Nil(t, sig)
+}
+
+func TestParseSignature_EmptyData(t *testing.T) {
+ emptyJSON := []byte(`{}`)
+
+ sig, err := ParseSignature(emptyJSON)
+ require.NoError(t, err)
+ assert.NotNil(t, sig)
+ assert.Empty(t, sig.Signature)
+ assert.True(t, sig.Timestamp.IsZero())
+ assert.Equal(t, KeyID{}, sig.KeyID)
+ assert.Empty(t, sig.Algorithm)
+ assert.Empty(t, sig.HashAlgo)
+}
+
+func TestParseSignature_MissingFields(t *testing.T) {
+ // JSON with only some fields
+ partialJSON := []byte(`{
+ "signature": "AQIDBA==",
+ "algorithm": "ed25519"
+ }`)
+
+ sig, err := ParseSignature(partialJSON)
+ require.NoError(t, err)
+ assert.NotNil(t, sig)
+ assert.NotEmpty(t, sig.Signature)
+ assert.Equal(t, "ed25519", sig.Algorithm)
+ assert.True(t, sig.Timestamp.IsZero())
+}
+
+func TestSignature_MarshalUnmarshal_Roundtrip(t *testing.T) {
+ timestamp := time.Date(2024, 6, 20, 14, 45, 30, 0, time.UTC)
+ keyID, err := ParseKeyID("fedcba9876543210")
+ require.NoError(t, err)
+
+ original := Signature{
+ Signature: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe},
+ Timestamp: timestamp,
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: "sha512",
+ }
+
+ // Marshal
+ jsonData, err := json.Marshal(original)
+ require.NoError(t, err)
+
+ // Unmarshal
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+
+ // Verify
+ assert.Equal(t, original.Signature, parsed.Signature)
+ assert.Equal(t, original.Timestamp.Unix(), parsed.Timestamp.Unix())
+ assert.Equal(t, original.KeyID, parsed.KeyID)
+ assert.Equal(t, original.Algorithm, parsed.Algorithm)
+ assert.Equal(t, original.HashAlgo, parsed.HashAlgo)
+}
+
+func TestSignature_NilSignatureBytes(t *testing.T) {
+ timestamp := time.Now().UTC()
+ keyID, err := ParseKeyID("0011223344556677")
+ require.NoError(t, err)
+
+ sig := Signature{
+ Signature: nil,
+ Timestamp: timestamp,
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ jsonData, err := json.Marshal(sig)
+ require.NoError(t, err)
+
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+ assert.Nil(t, parsed.Signature)
+}
+
+func TestSignature_LargeSignature(t *testing.T) {
+ timestamp := time.Now().UTC()
+ keyID, err := ParseKeyID("aabbccddeeff0011")
+ require.NoError(t, err)
+
+ // Create a large signature (64 bytes for ed25519)
+ largeSignature := make([]byte, 64)
+ for i := range largeSignature {
+ largeSignature[i] = byte(i)
+ }
+
+ sig := Signature{
+ Signature: largeSignature,
+ Timestamp: timestamp,
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ jsonData, err := json.Marshal(sig)
+ require.NoError(t, err)
+
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+ assert.Equal(t, largeSignature, parsed.Signature)
+}
+
+func TestSignature_WithDifferentHashAlgorithms(t *testing.T) {
+ tests := []struct {
+ name string
+ hashAlgo string
+ }{
+ {"blake2s", "blake2s"},
+ {"sha512", "sha512"},
+ {"sha256", "sha256"},
+ {"empty", ""},
+ }
+
+ keyID, err := ParseKeyID("1122334455667788")
+ require.NoError(t, err)
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ sig := Signature{
+ Signature: []byte{0x01, 0x02},
+ Timestamp: time.Now().UTC(),
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: tt.hashAlgo,
+ }
+
+ jsonData, err := json.Marshal(sig)
+ require.NoError(t, err)
+
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+ assert.Equal(t, tt.hashAlgo, parsed.HashAlgo)
+ })
+ }
+}
+
+func TestSignature_TimestampPrecision(t *testing.T) {
+ // Test that timestamp preserves precision through JSON marshaling
+ timestamp := time.Date(2024, 3, 15, 10, 30, 45, 123456789, time.UTC)
+ keyID, err := ParseKeyID("8877665544332211")
+ require.NoError(t, err)
+
+ sig := Signature{
+ Signature: []byte{0xaa, 0xbb},
+ Timestamp: timestamp,
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ jsonData, err := json.Marshal(sig)
+ require.NoError(t, err)
+
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+
+ // JSON timestamps typically have second or millisecond precision
+ // so we check that at least seconds match
+ assert.Equal(t, timestamp.Unix(), parsed.Timestamp.Unix())
+}
+
+func TestParseSignature_MalformedKeyID(t *testing.T) {
+ // Test with a malformed KeyID field
+ malformedJSON := []byte(`{
+ "signature": "AQID",
+ "timestamp": "2024-01-15T10:30:00Z",
+ "key_id": "invalid_keyid_format",
+ "algorithm": "ed25519",
+ "hash_algo": "blake2s"
+ }`)
+
+ // This should fail since "invalid_keyid_format" is not a valid KeyID
+ sig, err := ParseSignature(malformedJSON)
+ assert.Error(t, err)
+ assert.Nil(t, sig)
+}
+
+func TestParseSignature_InvalidTimestamp(t *testing.T) {
+ // Test with an invalid timestamp format
+ invalidTimestampJSON := []byte(`{
+ "signature": "AQID",
+ "timestamp": "not-a-timestamp",
+ "key_id": "0123456789abcdef",
+ "algorithm": "ed25519",
+ "hash_algo": "blake2s"
+ }`)
+
+ sig, err := ParseSignature(invalidTimestampJSON)
+ assert.Error(t, err)
+ assert.Nil(t, sig)
+}
+
+func TestSignature_ZeroKeyID(t *testing.T) {
+ // Test with a zero KeyID
+ sig := Signature{
+ Signature: []byte{0x01, 0x02, 0x03},
+ Timestamp: time.Now().UTC(),
+ KeyID: KeyID{},
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ jsonData, err := json.Marshal(sig)
+ require.NoError(t, err)
+
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+ assert.Equal(t, KeyID{}, parsed.KeyID)
+}
+
+func TestParseSignature_ExtraFields(t *testing.T) {
+ // JSON with extra fields that should be ignored
+ jsonWithExtra := []byte(`{
+ "signature": "AQIDBA==",
+ "timestamp": "2024-01-15T10:30:00Z",
+ "key_id": "0123456789abcdef",
+ "algorithm": "ed25519",
+ "hash_algo": "blake2s",
+ "extra_field": "should be ignored",
+ "another_extra": 12345
+ }`)
+
+ sig, err := ParseSignature(jsonWithExtra)
+ require.NoError(t, err)
+ assert.NotNil(t, sig)
+ assert.NotEmpty(t, sig.Signature)
+ assert.Equal(t, "ed25519", sig.Algorithm)
+ assert.Equal(t, "blake2s", sig.HashAlgo)
+}
diff --git a/client/internal/updatemanager/reposign/verify.go b/client/internal/updatemanager/reposign/verify.go
new file mode 100644
index 000000000..0af2a8c9e
--- /dev/null
+++ b/client/internal/updatemanager/reposign/verify.go
@@ -0,0 +1,187 @@
+package reposign
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
+)
+
+const (
+ artifactPubKeysFileName = "artifact-key-pub.pem"
+ artifactPubKeysSigFileName = "artifact-key-pub.pem.sig"
+ revocationFileName = "revocation-list.json"
+ revocationSignFileName = "revocation-list.json.sig"
+
+ keySizeLimit = 5 * 1024 * 1024 //5MB
+ signatureLimit = 1024
+ revocationLimit = 10 * 1024 * 1024
+)
+
+type ArtifactVerify struct {
+ rootKeys []PublicKey
+ keysBaseURL *url.URL
+
+ revocationList *RevocationList
+}
+
+func NewArtifactVerify(keysBaseURL string) (*ArtifactVerify, error) {
+ allKeys, err := loadEmbeddedPublicKeys()
+ if err != nil {
+ return nil, err
+ }
+
+ return newArtifactVerify(keysBaseURL, allKeys)
+}
+
+func newArtifactVerify(keysBaseURL string, allKeys []PublicKey) (*ArtifactVerify, error) {
+ ku, err := url.Parse(keysBaseURL)
+ if err != nil {
+ return nil, fmt.Errorf("invalid keys base URL %q: %v", keysBaseURL, err)
+ }
+
+ a := &ArtifactVerify{
+ rootKeys: allKeys,
+ keysBaseURL: ku,
+ }
+ return a, nil
+}
+
+func (a *ArtifactVerify) Verify(ctx context.Context, version string, artifactFile string) error {
+ version = strings.TrimPrefix(version, "v")
+
+ revocationList, err := a.loadRevocationList(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to load revocation list: %v", err)
+ }
+ a.revocationList = revocationList
+
+ artifactPubKeys, err := a.loadArtifactKeys(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to load artifact keys: %v", err)
+ }
+
+ signature, err := a.loadArtifactSignature(ctx, version, artifactFile)
+ if err != nil {
+ return fmt.Errorf("failed to download signature file for: %s, %v", filepath.Base(artifactFile), err)
+ }
+
+ artifactData, err := os.ReadFile(artifactFile)
+ if err != nil {
+ log.Errorf("failed to read artifact file: %v", err)
+ return fmt.Errorf("failed to read artifact file: %w", err)
+ }
+
+ if err := ValidateArtifact(artifactPubKeys, artifactData, *signature); err != nil {
+ return fmt.Errorf("failed to validate artifact: %v", err)
+ }
+
+ return nil
+}
+
+func (a *ArtifactVerify) loadRevocationList(ctx context.Context) (*RevocationList, error) {
+ downloadURL := a.keysBaseURL.JoinPath("keys", revocationFileName).String()
+ data, err := downloader.DownloadToMemory(ctx, downloadURL, revocationLimit)
+ if err != nil {
+ log.Debugf("failed to download revocation list '%s': %s", downloadURL, err)
+ return nil, err
+ }
+
+ downloadURL = a.keysBaseURL.JoinPath("keys", revocationSignFileName).String()
+ sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
+ if err != nil {
+ log.Debugf("failed to download revocation list '%s': %s", downloadURL, err)
+ return nil, err
+ }
+
+ signature, err := ParseSignature(sigData)
+ if err != nil {
+ log.Debugf("failed to parse revocation list signature: %s", err)
+ return nil, err
+ }
+
+ return ValidateRevocationList(a.rootKeys, data, *signature)
+}
+
+func (a *ArtifactVerify) loadArtifactKeys(ctx context.Context) ([]PublicKey, error) {
+ downloadURL := a.keysBaseURL.JoinPath("keys", artifactPubKeysFileName).String()
+ log.Debugf("starting downloading artifact keys from: %s", downloadURL)
+ data, err := downloader.DownloadToMemory(ctx, downloadURL, keySizeLimit)
+ if err != nil {
+ log.Debugf("failed to download artifact keys: %s", err)
+ return nil, err
+ }
+
+ downloadURL = a.keysBaseURL.JoinPath("keys", artifactPubKeysSigFileName).String()
+ log.Debugf("start downloading signature of artifact pub key from: %s", downloadURL)
+ sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
+ if err != nil {
+ log.Debugf("failed to download signature of public keys: %s", err)
+ return nil, err
+ }
+
+ signature, err := ParseSignature(sigData)
+ if err != nil {
+ log.Debugf("failed to parse signature of public keys: %s", err)
+ return nil, err
+ }
+
+ return ValidateArtifactKeys(a.rootKeys, data, *signature, a.revocationList)
+}
+
+func (a *ArtifactVerify) loadArtifactSignature(ctx context.Context, version string, artifactFile string) (*Signature, error) {
+ artifactFile = filepath.Base(artifactFile)
+ downloadURL := a.keysBaseURL.JoinPath("tag", "v"+version, artifactFile+".sig").String()
+ data, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
+ if err != nil {
+ log.Debugf("failed to download artifact signature: %s", err)
+ return nil, err
+ }
+
+ signature, err := ParseSignature(data)
+ if err != nil {
+ log.Debugf("failed to parse artifact signature: %s", err)
+ return nil, err
+ }
+
+ return signature, nil
+
+}
+
+func loadEmbeddedPublicKeys() ([]PublicKey, error) {
+ files, err := embeddedCerts.ReadDir(embeddedCertsDir)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read embedded certs: %w", err)
+ }
+
+ var allKeys []PublicKey
+ for _, file := range files {
+ if file.IsDir() {
+ continue
+ }
+
+ data, err := embeddedCerts.ReadFile(embeddedCertsDir + "/" + file.Name())
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cert file %s: %w", file.Name(), err)
+ }
+
+ keys, err := parsePublicKeyBundle(data, tagRootPublic)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse cert %s: %w", file.Name(), err)
+ }
+
+ allKeys = append(allKeys, keys...)
+ }
+
+ if len(allKeys) == 0 {
+ return nil, fmt.Errorf("no valid public keys found in embedded certs")
+ }
+
+ return allKeys, nil
+}
diff --git a/client/internal/updatemanager/reposign/verify_test.go b/client/internal/updatemanager/reposign/verify_test.go
new file mode 100644
index 000000000..c29393bad
--- /dev/null
+++ b/client/internal/updatemanager/reposign/verify_test.go
@@ -0,0 +1,528 @@
+package reposign
+
+import (
+ "context"
+ "crypto/ed25519"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test ArtifactVerify construction
+
+func TestArtifactVerify_Construction(t *testing.T) {
+ // Generate test root key
+ rootKey, _, rootPubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ rootPubKey, _, err := parsePublicKey(rootPubPEM, tagRootPublic)
+ require.NoError(t, err)
+
+ keysBaseURL := "http://localhost:8080/artifact-signatures"
+
+ av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey})
+ require.NoError(t, err)
+
+ assert.NotNil(t, av)
+ assert.NotEmpty(t, av.rootKeys)
+ assert.Equal(t, keysBaseURL, av.keysBaseURL.String())
+
+ // Verify root key structure
+ assert.NotEmpty(t, av.rootKeys[0].Key)
+ assert.Equal(t, rootKey.Metadata.ID, av.rootKeys[0].Metadata.ID)
+ assert.False(t, av.rootKeys[0].Metadata.CreatedAt.IsZero())
+}
+
+func TestArtifactVerify_MultipleRootKeys(t *testing.T) {
+ // Generate multiple test root keys
+ rootKey1, _, rootPubPEM1, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+ rootPubKey1, _, err := parsePublicKey(rootPubPEM1, tagRootPublic)
+ require.NoError(t, err)
+
+ rootKey2, _, rootPubPEM2, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+ rootPubKey2, _, err := parsePublicKey(rootPubPEM2, tagRootPublic)
+ require.NoError(t, err)
+
+ keysBaseURL := "http://localhost:8080/artifact-signatures"
+
+ av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey1, rootPubKey2})
+ assert.NoError(t, err)
+ assert.Len(t, av.rootKeys, 2)
+ assert.NotEqual(t, rootKey1.Metadata.ID, rootKey2.Metadata.ID)
+}
+
+// Test Verify workflow with mock HTTP server
+
+func TestArtifactVerify_FullWorkflow(t *testing.T) {
+ // Create temporary test directory
+ tempDir := t.TempDir()
+
+ // Step 1: Generate root key
+ rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Step 2: Generate artifact key
+ artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
+ require.NoError(t, err)
+
+ // Step 3: Create revocation list
+ revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Step 4: Bundle artifact keys
+ artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
+ require.NoError(t, err)
+
+ // Step 5: Create test artifact
+ artifactPath := filepath.Join(tempDir, "test-artifact.bin")
+ artifactData := []byte("This is test artifact data for verification")
+ err = os.WriteFile(artifactPath, artifactData, 0644)
+ require.NoError(t, err)
+
+ // Step 6: Sign artifact
+ artifactSigData, err := SignData(*artifactKey, artifactData)
+ require.NoError(t, err)
+
+ // Step 7: Setup mock HTTP server
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/artifact-signatures/keys/" + revocationFileName:
+ _, _ = w.Write(revocationData)
+ case "/artifact-signatures/keys/" + revocationSignFileName:
+ _, _ = w.Write(revocationSig)
+ case "/artifact-signatures/keys/" + artifactPubKeysFileName:
+ _, _ = w.Write(artifactKeysBundle)
+ case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
+ _, _ = w.Write(artifactKeysSig)
+ case "/artifacts/v1.0.0/test-artifact.bin":
+ _, _ = w.Write(artifactData)
+ case "/artifact-signatures/tag/v1.0.0/test-artifact.bin.sig":
+ _, _ = w.Write(artifactSigData)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ // Step 8: Create ArtifactVerify with test root key
+ rootPubKey := PublicKey{
+ Key: rootKey.Key.Public().(ed25519.PublicKey),
+ Metadata: rootKey.Metadata,
+ }
+
+ av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+ require.NoError(t, err)
+
+ // Step 9: Verify artifact
+ ctx := context.Background()
+ err = av.Verify(ctx, "1.0.0", artifactPath)
+ assert.NoError(t, err)
+}
+
+func TestArtifactVerify_InvalidRevocationList(t *testing.T) {
+ tempDir := t.TempDir()
+ artifactPath := filepath.Join(tempDir, "test.bin")
+ err := os.WriteFile(artifactPath, []byte("test"), 0644)
+ require.NoError(t, err)
+
+ // Setup server with invalid revocation list
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/artifact-signatures/keys/" + revocationFileName:
+ _, _ = w.Write([]byte("invalid data"))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ rootPubKey := PublicKey{
+ Key: rootKey.Key.Public().(ed25519.PublicKey),
+ Metadata: rootKey.Metadata,
+ }
+
+ av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+ require.NoError(t, err)
+
+ ctx := context.Background()
+ err = av.Verify(ctx, "1.0.0", artifactPath)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to load revocation list")
+}
+
+func TestArtifactVerify_MissingArtifactFile(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ rootPubKey := PublicKey{
+ Key: rootKey.Key.Public().(ed25519.PublicKey),
+ Metadata: rootKey.Metadata,
+ }
+
+ // Create revocation list
+ revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
+ require.NoError(t, err)
+
+ artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
+ require.NoError(t, err)
+
+ // Create signature for non-existent file
+ testData := []byte("test")
+ artifactSigData, err := SignData(*artifactKey, testData)
+ require.NoError(t, err)
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/artifact-signatures/keys/" + revocationFileName:
+ _, _ = w.Write(revocationData)
+ case "/artifact-signatures/keys/" + revocationSignFileName:
+ _, _ = w.Write(revocationSig)
+ case "/artifact-signatures/keys/" + artifactPubKeysFileName:
+ _, _ = w.Write(artifactKeysBundle)
+ case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
+ _, _ = w.Write(artifactKeysSig)
+ case "/artifact-signatures/tag/v1.0.0/missing.bin.sig":
+ _, _ = w.Write(artifactSigData)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+ require.NoError(t, err)
+
+ ctx := context.Background()
+ err = av.Verify(ctx, "1.0.0", "file.bin")
+ assert.Error(t, err)
+}
+
+func TestArtifactVerify_ServerUnavailable(t *testing.T) {
+ tempDir := t.TempDir()
+ artifactPath := filepath.Join(tempDir, "test.bin")
+ err := os.WriteFile(artifactPath, []byte("test"), 0644)
+ require.NoError(t, err)
+
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ rootPubKey := PublicKey{
+ Key: rootKey.Key.Public().(ed25519.PublicKey),
+ Metadata: rootKey.Metadata,
+ }
+
+ // Use URL that doesn't exist
+ av, err := newArtifactVerify("http://localhost:19999/keys", []PublicKey{rootPubKey})
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+
+ err = av.Verify(ctx, "1.0.0", artifactPath)
+ assert.Error(t, err)
+}
+
+func TestArtifactVerify_ContextCancellation(t *testing.T) {
+ tempDir := t.TempDir()
+ artifactPath := filepath.Join(tempDir, "test.bin")
+ err := os.WriteFile(artifactPath, []byte("test"), 0644)
+ require.NoError(t, err)
+
+ // Create a server that delays response
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ time.Sleep(500 * time.Millisecond)
+ _, _ = w.Write([]byte("data"))
+ }))
+ defer server.Close()
+
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ rootPubKey := PublicKey{
+ Key: rootKey.Key.Public().(ed25519.PublicKey),
+ Metadata: rootKey.Metadata,
+ }
+
+ av, err := newArtifactVerify(server.URL, []PublicKey{rootPubKey})
+ require.NoError(t, err)
+
+ // Create context that cancels quickly
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+ defer cancel()
+
+ err = av.Verify(ctx, "1.0.0", artifactPath)
+ assert.Error(t, err)
+}
+
+func TestArtifactVerify_WithRevocation(t *testing.T) {
+ tempDir := t.TempDir()
+
+ // Generate root key
+ rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Generate two artifact keys
+ artifactKey1, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
+ require.NoError(t, err)
+
+ _, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
+ require.NoError(t, err)
+
+ // Create revocation list with first key revoked
+ emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ parsedRevocation, err := ParseRevocationList(emptyRevocation)
+ require.NoError(t, err)
+
+ revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Bundle both keys
+ artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
+ require.NoError(t, err)
+
+ // Create artifact signed by revoked key
+ artifactPath := filepath.Join(tempDir, "test.bin")
+ artifactData := []byte("test data")
+ err = os.WriteFile(artifactPath, artifactData, 0644)
+ require.NoError(t, err)
+
+ artifactSigData, err := SignData(*artifactKey1, artifactData)
+ require.NoError(t, err)
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/artifact-signatures/keys/" + revocationFileName:
+ _, _ = w.Write(revocationData)
+ case "/artifact-signatures/keys/" + revocationSignFileName:
+ _, _ = w.Write(revocationSig)
+ case "/artifact-signatures/keys/" + artifactPubKeysFileName:
+ _, _ = w.Write(artifactKeysBundle)
+ case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
+ _, _ = w.Write(artifactKeysSig)
+ case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
+ _, _ = w.Write(artifactSigData)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ rootPubKey := PublicKey{
+ Key: rootKey.Key.Public().(ed25519.PublicKey),
+ Metadata: rootKey.Metadata,
+ }
+
+ av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+ require.NoError(t, err)
+
+ ctx := context.Background()
+ err = av.Verify(ctx, "1.0.0", artifactPath)
+ // Should fail because the signing key is revoked
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "no signing Key found")
+}
+
+func TestArtifactVerify_ValidWithSecondKey(t *testing.T) {
+ tempDir := t.TempDir()
+
+ // Generate root key
+ rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Generate two artifact keys
+ _, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
+ require.NoError(t, err)
+
+ artifactKey2, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
+ require.NoError(t, err)
+
+ // Create revocation list with first key revoked
+ emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ parsedRevocation, err := ParseRevocationList(emptyRevocation)
+ require.NoError(t, err)
+
+ revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Bundle both keys
+ artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
+ require.NoError(t, err)
+
+ // Create artifact signed by second key (not revoked)
+ artifactPath := filepath.Join(tempDir, "test.bin")
+ artifactData := []byte("test data")
+ err = os.WriteFile(artifactPath, artifactData, 0644)
+ require.NoError(t, err)
+
+ artifactSigData, err := SignData(*artifactKey2, artifactData)
+ require.NoError(t, err)
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/artifact-signatures/keys/" + revocationFileName:
+ _, _ = w.Write(revocationData)
+ case "/artifact-signatures/keys/" + revocationSignFileName:
+ _, _ = w.Write(revocationSig)
+ case "/artifact-signatures/keys/" + artifactPubKeysFileName:
+ _, _ = w.Write(artifactKeysBundle)
+ case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
+ _, _ = w.Write(artifactKeysSig)
+ case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
+ _, _ = w.Write(artifactSigData)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ rootPubKey := PublicKey{
+ Key: rootKey.Key.Public().(ed25519.PublicKey),
+ Metadata: rootKey.Metadata,
+ }
+
+ av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+ require.NoError(t, err)
+
+ ctx := context.Background()
+ err = av.Verify(ctx, "1.0.0", artifactPath)
+ // Should succeed because second key is not revoked
+ assert.NoError(t, err)
+}
+
+func TestArtifactVerify_TamperedArtifact(t *testing.T) {
+ tempDir := t.TempDir()
+
+ // Generate root key and artifact key
+ rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
+ require.NoError(t, err)
+
+ // Create revocation list
+ revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Bundle keys
+ artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
+ require.NoError(t, err)
+
+ // Sign original data
+ originalData := []byte("original data")
+ artifactSigData, err := SignData(*artifactKey, originalData)
+ require.NoError(t, err)
+
+ // Write tampered data to file
+ artifactPath := filepath.Join(tempDir, "test.bin")
+ tamperedData := []byte("tampered data")
+ err = os.WriteFile(artifactPath, tamperedData, 0644)
+ require.NoError(t, err)
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/artifact-signatures/keys/" + revocationFileName:
+ _, _ = w.Write(revocationData)
+ case "/artifact-signatures/keys/" + revocationSignFileName:
+ _, _ = w.Write(revocationSig)
+ case "/artifact-signatures/keys/" + artifactPubKeysFileName:
+ _, _ = w.Write(artifactKeysBundle)
+ case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
+ _, _ = w.Write(artifactKeysSig)
+ case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
+ _, _ = w.Write(artifactSigData)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ rootPubKey := PublicKey{
+ Key: rootKey.Key.Public().(ed25519.PublicKey),
+ Metadata: rootKey.Metadata,
+ }
+
+ av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+ require.NoError(t, err)
+
+ ctx := context.Background()
+ err = av.Verify(ctx, "1.0.0", artifactPath)
+ // Should fail because artifact was tampered
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to validate artifact")
+}
+
+// Test URL validation
+
+func TestArtifactVerify_URLParsing(t *testing.T) {
+ tests := []struct {
+ name string
+ keysBaseURL string
+ expectError bool
+ }{
+ {
+ name: "Valid HTTP URL",
+ keysBaseURL: "http://example.com/artifact-signatures",
+ expectError: false,
+ },
+ {
+ name: "Valid HTTPS URL",
+ keysBaseURL: "https://example.com/artifact-signatures",
+ expectError: false,
+ },
+ {
+ name: "URL with port",
+ keysBaseURL: "http://localhost:8080/artifact-signatures",
+ expectError: false,
+ },
+ {
+ name: "Invalid URL",
+ keysBaseURL: "://invalid",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := newArtifactVerify(tt.keysBaseURL, nil)
+ if tt.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/client/internal/updatemanager/update.go b/client/internal/updatemanager/update.go
new file mode 100644
index 000000000..875b50b49
--- /dev/null
+++ b/client/internal/updatemanager/update.go
@@ -0,0 +1,11 @@
+package updatemanager
+
+import v "github.com/hashicorp/go-version"
+
+type UpdateInterface interface {
+ StopWatch()
+ SetDaemonVersion(newVersion string) bool
+ SetOnUpdateListener(updateFn func())
+ LatestVersion() *v.Version
+ StartFetcher()
+}
diff --git a/client/internal/winregistry/volatile_windows.go b/client/internal/winregistry/volatile_windows.go
new file mode 100644
index 000000000..a8e350fe7
--- /dev/null
+++ b/client/internal/winregistry/volatile_windows.go
@@ -0,0 +1,59 @@
+package winregistry
+
+import (
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows/registry"
+)
+
+var (
+ advapi = syscall.NewLazyDLL("advapi32.dll")
+ regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
+)
+
+const (
+ // Registry key options
+ regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
+ regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
+
+ // Registry disposition values
+ regCreatedNewKey = 0x1
+ regOpenedExistingKey = 0x2
+)
+
+// CreateVolatileKey creates a volatile registry key named path under open key root.
+// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
+// The access parameter specifies the access rights for the key to be created.
+//
+// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
+// This provides automatic cleanup without requiring manual registry maintenance.
+func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
+ pathPtr, err := syscall.UTF16PtrFromString(path)
+ if err != nil {
+ return 0, false, err
+ }
+
+ var (
+ handle syscall.Handle
+ disposition uint32
+ )
+
+ ret, _, _ := regCreateKeyExW.Call(
+ uintptr(root),
+ uintptr(unsafe.Pointer(pathPtr)),
+ 0, // reserved
+ 0, // class
+ uintptr(regOptionVolatile), // options - volatile key
+ uintptr(access), // desired access
+ 0, // security attributes
+ uintptr(unsafe.Pointer(&handle)),
+ uintptr(unsafe.Pointer(&disposition)),
+ )
+
+ if ret != 0 {
+ return 0, false, syscall.Errno(ret)
+ }
+
+ return registry.Key(handle), disposition == regOpenedExistingKey, nil
+}
diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go
index 2109d4b15..f3458ccea 100644
--- a/client/ios/NetBirdSDK/client.go
+++ b/client/ios/NetBirdSDK/client.go
@@ -1,9 +1,12 @@
+//go:build ios
+
package NetBirdSDK
import (
"context"
"fmt"
"net/netip"
+ "os"
"sort"
"strings"
"sync"
@@ -20,8 +23,8 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
- "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
+ "github.com/netbirdio/netbird/shared/management/domain"
)
// ConnectionListener export internal Listener for mobile
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
}
// Run start the internal client. It is a blocker function
-func (c *Client) Run(fd int32, interfaceName string) error {
+func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
+ exportEnvList(envList)
log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
@@ -127,7 +131,7 @@ func (c *Client) Run(fd int32, interfaceName string) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
- c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
+ c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
@@ -228,7 +232,7 @@ func (c *Client) LoginForMobile() string {
ConfigPath: c.cfgFile,
})
- oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
+ oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
if err != nil {
return err.Error()
}
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
}
return netIDs
}
+
+func exportEnvList(list *EnvList) {
+ if list == nil {
+ return
+ }
+ for k, v := range list.AllItems() {
+ log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
+ log.Debugf("Setting env variable %s: %s", k, v)
+
+ if err := os.Setenv(k, v); err != nil {
+ log.Errorf("could not set env variable %s: %v", k, err)
+ } else {
+ log.Debugf("Env variable %s was set successfully", k)
+ }
+ }
+}
diff --git a/client/ios/NetBirdSDK/env_list.go b/client/ios/NetBirdSDK/env_list.go
new file mode 100644
index 000000000..4800803d7
--- /dev/null
+++ b/client/ios/NetBirdSDK/env_list.go
@@ -0,0 +1,34 @@
+//go:build ios
+
+package NetBirdSDK
+
+import "github.com/netbirdio/netbird/client/internal/peer"
+
+// EnvList is an exported struct to be bound by gomobile
+type EnvList struct {
+ data map[string]string
+}
+
+// NewEnvList creates a new EnvList
+func NewEnvList() *EnvList {
+ return &EnvList{data: make(map[string]string)}
+}
+
+// Put adds a key-value pair
+func (el *EnvList) Put(key, value string) {
+ el.data[key] = value
+}
+
+// Get retrieves a value by key
+func (el *EnvList) Get(key string) string {
+ return el.data[key]
+}
+
+func (el *EnvList) AllItems() map[string]string {
+ return el.data
+}
+
+// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
+func GetEnvKeyNBForceRelay() string {
+ return peer.EnvKeyNBForceRelay
+}
diff --git a/client/ios/NetBirdSDK/gomobile.go b/client/ios/NetBirdSDK/gomobile.go
index 9eadd6a7f..79bf0c2ac 100644
--- a/client/ios/NetBirdSDK/gomobile.go
+++ b/client/ios/NetBirdSDK/gomobile.go
@@ -1,3 +1,5 @@
+//go:build ios
+
package NetBirdSDK
import _ "golang.org/x/mobile/bind"
diff --git a/client/ios/NetBirdSDK/logger.go b/client/ios/NetBirdSDK/logger.go
index f1ad1b9f6..531d0ba89 100644
--- a/client/ios/NetBirdSDK/logger.go
+++ b/client/ios/NetBirdSDK/logger.go
@@ -1,3 +1,5 @@
+//go:build ios
+
package NetBirdSDK
import (
diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go
index 570c44f80..1c2b38a61 100644
--- a/client/ios/NetBirdSDK/login.go
+++ b/client/ios/NetBirdSDK/login.go
@@ -1,3 +1,5 @@
+//go:build ios
+
package NetBirdSDK
import (
diff --git a/client/ios/NetBirdSDK/peer_notifier.go b/client/ios/NetBirdSDK/peer_notifier.go
index 16c5039eb..9b00568be 100644
--- a/client/ios/NetBirdSDK/peer_notifier.go
+++ b/client/ios/NetBirdSDK/peer_notifier.go
@@ -1,3 +1,5 @@
+//go:build ios
+
package NetBirdSDK
// PeerInfo describe information about the peers. It designed for the UI usage
diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go
index 5e7050465..39ae06538 100644
--- a/client/ios/NetBirdSDK/preferences.go
+++ b/client/ios/NetBirdSDK/preferences.go
@@ -1,3 +1,5 @@
+//go:build ios
+
package NetBirdSDK
import (
diff --git a/client/ios/NetBirdSDK/preferences_test.go b/client/ios/NetBirdSDK/preferences_test.go
index 780443a7b..5f75e7c9a 100644
--- a/client/ios/NetBirdSDK/preferences_test.go
+++ b/client/ios/NetBirdSDK/preferences_test.go
@@ -1,3 +1,5 @@
+//go:build ios
+
package NetBirdSDK
import (
diff --git a/client/ios/NetBirdSDK/routes.go b/client/ios/NetBirdSDK/routes.go
index 30d0d0d0a..7b84d6e1c 100644
--- a/client/ios/NetBirdSDK/routes.go
+++ b/client/ios/NetBirdSDK/routes.go
@@ -1,3 +1,5 @@
+//go:build ios
+
package NetBirdSDK
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection
diff --git a/client/net/conn.go b/client/net/conn.go
index 918e7f628..bf54c792d 100644
--- a/client/net/conn.go
+++ b/client/net/conn.go
@@ -17,8 +17,7 @@ type Conn struct {
ID hooks.ConnectionID
}
-// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
-// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
+// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
func (c *Conn) Close() error {
return closeConn(c.ID, c.Conn)
}
@@ -29,7 +28,7 @@ type TCPConn struct {
ID hooks.ConnectionID
}
-// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
+// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
func (c *TCPConn) Close() error {
return closeConn(c.ID, c.TCPConn)
}
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
// closeConn is a helper function to close connections and execute close hooks.
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
err := conn.Close()
+ cleanupConnID(id)
+ return err
+}
+// cleanupConnID executes close hooks for a connection ID.
+func cleanupConnID(id hooks.ConnectionID) {
closeHooks := hooks.GetCloseHooks()
for _, hook := range closeHooks {
if err := hook(id); err != nil {
log.Errorf("Error executing close hook: %v", err)
}
}
-
- return err
}
diff --git a/client/net/dial.go b/client/net/dial.go
index 041a00e5d..17c9ff98a 100644
--- a/client/net/dial.go
+++ b/client/net/dial.go
@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
}
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
}
-
if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}
diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go
index 2e1eb53d8..1e275013f 100644
--- a/client/net/dialer_dial.go
+++ b/client/net/dialer_dial.go
@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
+ cleanupConnID(connID)
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
}
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
- return fmt.Errorf("failed to resolve address %s: %w", address, err)
+ return fmt.Errorf("resolve address %s: %w", address, err)
}
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go
index 0bb5ad67d..a150172b4 100644
--- a/client/net/listener_listen.go
+++ b/client/net/listener_listen.go
@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.PacketConn.WriteTo(b, addr)
}
-// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
+// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
func (c *PacketConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.PacketConn)
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.UDPConn.WriteTo(b, addr)
}
-// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
+// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
func (c *UDPConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.UDPConn)
diff --git a/client/netbird.wxs b/client/netbird.wxs
index ba827debf..03221dd91 100644
--- a/client/netbird.wxs
+++ b/client/netbird.wxs
@@ -51,7 +51,7 @@
-
+
diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go
index 34230a5b4..80e5bb9c5 100644
--- a/client/proto/daemon.pb.go
+++ b/client/proto/daemon.pb.go
@@ -1,21 +1,20 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
-// protoc v6.32.1
+// protoc v6.33.1
// source: daemon.proto
package proto
import (
- reflect "reflect"
- sync "sync"
- unsafe "unsafe"
-
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
_ "google.golang.org/protobuf/types/descriptorpb"
durationpb "google.golang.org/protobuf/types/known/durationpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
+ reflect "reflect"
+ sync "sync"
+ unsafe "unsafe"
)
const (
@@ -89,6 +88,56 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{0}
}
+// avoid collision with loglevel enum
+type OSLifecycleRequest_CycleType int32
+
+const (
+ OSLifecycleRequest_UNKNOWN OSLifecycleRequest_CycleType = 0
+ OSLifecycleRequest_SLEEP OSLifecycleRequest_CycleType = 1
+ OSLifecycleRequest_WAKEUP OSLifecycleRequest_CycleType = 2
+)
+
+// Enum value maps for OSLifecycleRequest_CycleType.
+var (
+ OSLifecycleRequest_CycleType_name = map[int32]string{
+ 0: "UNKNOWN",
+ 1: "SLEEP",
+ 2: "WAKEUP",
+ }
+ OSLifecycleRequest_CycleType_value = map[string]int32{
+ "UNKNOWN": 0,
+ "SLEEP": 1,
+ "WAKEUP": 2,
+ }
+)
+
+func (x OSLifecycleRequest_CycleType) Enum() *OSLifecycleRequest_CycleType {
+ p := new(OSLifecycleRequest_CycleType)
+ *p = x
+ return p
+}
+
+func (x OSLifecycleRequest_CycleType) String() string {
+ return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
+}
+
+func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor {
+ return file_daemon_proto_enumTypes[1].Descriptor()
+}
+
+func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType {
+ return &file_daemon_proto_enumTypes[1]
+}
+
+func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber {
+ return protoreflect.EnumNumber(x)
+}
+
+// Deprecated: Use OSLifecycleRequest_CycleType.Descriptor instead.
+func (OSLifecycleRequest_CycleType) EnumDescriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{1, 0}
+}
+
type SystemEvent_Severity int32
const (
@@ -125,11 +174,11 @@ func (x SystemEvent_Severity) String() string {
}
func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor {
- return file_daemon_proto_enumTypes[1].Descriptor()
+ return file_daemon_proto_enumTypes[2].Descriptor()
}
func (SystemEvent_Severity) Type() protoreflect.EnumType {
- return &file_daemon_proto_enumTypes[1]
+ return &file_daemon_proto_enumTypes[2]
}
func (x SystemEvent_Severity) Number() protoreflect.EnumNumber {
@@ -138,7 +187,7 @@ func (x SystemEvent_Severity) Number() protoreflect.EnumNumber {
// Deprecated: Use SystemEvent_Severity.Descriptor instead.
func (SystemEvent_Severity) EnumDescriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{49, 0}
+ return file_daemon_proto_rawDescGZIP(), []int{53, 0}
}
type SystemEvent_Category int32
@@ -180,11 +229,11 @@ func (x SystemEvent_Category) String() string {
}
func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor {
- return file_daemon_proto_enumTypes[2].Descriptor()
+ return file_daemon_proto_enumTypes[3].Descriptor()
}
func (SystemEvent_Category) Type() protoreflect.EnumType {
- return &file_daemon_proto_enumTypes[2]
+ return &file_daemon_proto_enumTypes[3]
}
func (x SystemEvent_Category) Number() protoreflect.EnumNumber {
@@ -193,7 +242,7 @@ func (x SystemEvent_Category) Number() protoreflect.EnumNumber {
// Deprecated: Use SystemEvent_Category.Descriptor instead.
func (SystemEvent_Category) EnumDescriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{49, 1}
+ return file_daemon_proto_rawDescGZIP(), []int{53, 1}
}
type EmptyRequest struct {
@@ -232,6 +281,86 @@ func (*EmptyRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{0}
}
+type OSLifecycleRequest struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ Type OSLifecycleRequest_CycleType `protobuf:"varint,1,opt,name=type,proto3,enum=daemon.OSLifecycleRequest_CycleType" json:"type,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *OSLifecycleRequest) Reset() {
+ *x = OSLifecycleRequest{}
+ mi := &file_daemon_proto_msgTypes[1]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *OSLifecycleRequest) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*OSLifecycleRequest) ProtoMessage() {}
+
+func (x *OSLifecycleRequest) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[1]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use OSLifecycleRequest.ProtoReflect.Descriptor instead.
+func (*OSLifecycleRequest) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{1}
+}
+
+func (x *OSLifecycleRequest) GetType() OSLifecycleRequest_CycleType {
+ if x != nil {
+ return x.Type
+ }
+ return OSLifecycleRequest_UNKNOWN
+}
+
+type OSLifecycleResponse struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *OSLifecycleResponse) Reset() {
+ *x = OSLifecycleResponse{}
+ mi := &file_daemon_proto_msgTypes[2]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *OSLifecycleResponse) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*OSLifecycleResponse) ProtoMessage() {}
+
+func (x *OSLifecycleResponse) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[2]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use OSLifecycleResponse.ProtoReflect.Descriptor instead.
+func (*OSLifecycleResponse) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{2}
+}
+
type LoginRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
// setupKey netbird setup key.
@@ -280,13 +409,21 @@ type LoginRequest struct {
ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"`
Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
- unknownFields protoimpl.UnknownFields
- sizeCache protoimpl.SizeCache
+ // hint is used to pre-fill the email/username field during SSO authentication
+ Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"`
+ EnableSSHRoot *bool `protobuf:"varint,34,opt,name=enableSSHRoot,proto3,oneof" json:"enableSSHRoot,omitempty"`
+ EnableSSHSFTP *bool `protobuf:"varint,35,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
+ EnableSSHLocalPortForwarding *bool `protobuf:"varint,36,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"`
+ EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
+ DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
+ SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
}
func (x *LoginRequest) Reset() {
*x = LoginRequest{}
- mi := &file_daemon_proto_msgTypes[1]
+ mi := &file_daemon_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -298,7 +435,7 @@ func (x *LoginRequest) String() string {
func (*LoginRequest) ProtoMessage() {}
func (x *LoginRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[1]
+ mi := &file_daemon_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -311,7 +448,7 @@ func (x *LoginRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use LoginRequest.ProtoReflect.Descriptor instead.
func (*LoginRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{1}
+ return file_daemon_proto_rawDescGZIP(), []int{3}
}
func (x *LoginRequest) GetSetupKey() string {
@@ -539,6 +676,55 @@ func (x *LoginRequest) GetMtu() int64 {
return 0
}
+func (x *LoginRequest) GetHint() string {
+ if x != nil && x.Hint != nil {
+ return *x.Hint
+ }
+ return ""
+}
+
+func (x *LoginRequest) GetEnableSSHRoot() bool {
+ if x != nil && x.EnableSSHRoot != nil {
+ return *x.EnableSSHRoot
+ }
+ return false
+}
+
+func (x *LoginRequest) GetEnableSSHSFTP() bool {
+ if x != nil && x.EnableSSHSFTP != nil {
+ return *x.EnableSSHSFTP
+ }
+ return false
+}
+
+func (x *LoginRequest) GetEnableSSHLocalPortForwarding() bool {
+ if x != nil && x.EnableSSHLocalPortForwarding != nil {
+ return *x.EnableSSHLocalPortForwarding
+ }
+ return false
+}
+
+func (x *LoginRequest) GetEnableSSHRemotePortForwarding() bool {
+ if x != nil && x.EnableSSHRemotePortForwarding != nil {
+ return *x.EnableSSHRemotePortForwarding
+ }
+ return false
+}
+
+func (x *LoginRequest) GetDisableSSHAuth() bool {
+ if x != nil && x.DisableSSHAuth != nil {
+ return *x.DisableSSHAuth
+ }
+ return false
+}
+
+func (x *LoginRequest) GetSshJWTCacheTTL() int32 {
+ if x != nil && x.SshJWTCacheTTL != nil {
+ return *x.SshJWTCacheTTL
+ }
+ return 0
+}
+
type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -551,7 +737,7 @@ type LoginResponse struct {
func (x *LoginResponse) Reset() {
*x = LoginResponse{}
- mi := &file_daemon_proto_msgTypes[2]
+ mi := &file_daemon_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -563,7 +749,7 @@ func (x *LoginResponse) String() string {
func (*LoginResponse) ProtoMessage() {}
func (x *LoginResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[2]
+ mi := &file_daemon_proto_msgTypes[4]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -576,7 +762,7 @@ func (x *LoginResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use LoginResponse.ProtoReflect.Descriptor instead.
func (*LoginResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{2}
+ return file_daemon_proto_rawDescGZIP(), []int{4}
}
func (x *LoginResponse) GetNeedsSSOLogin() bool {
@@ -617,7 +803,7 @@ type WaitSSOLoginRequest struct {
func (x *WaitSSOLoginRequest) Reset() {
*x = WaitSSOLoginRequest{}
- mi := &file_daemon_proto_msgTypes[3]
+ mi := &file_daemon_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -629,7 +815,7 @@ func (x *WaitSSOLoginRequest) String() string {
func (*WaitSSOLoginRequest) ProtoMessage() {}
func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[3]
+ mi := &file_daemon_proto_msgTypes[5]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -642,7 +828,7 @@ func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use WaitSSOLoginRequest.ProtoReflect.Descriptor instead.
func (*WaitSSOLoginRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{3}
+ return file_daemon_proto_rawDescGZIP(), []int{5}
}
func (x *WaitSSOLoginRequest) GetUserCode() string {
@@ -668,7 +854,7 @@ type WaitSSOLoginResponse struct {
func (x *WaitSSOLoginResponse) Reset() {
*x = WaitSSOLoginResponse{}
- mi := &file_daemon_proto_msgTypes[4]
+ mi := &file_daemon_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -680,7 +866,7 @@ func (x *WaitSSOLoginResponse) String() string {
func (*WaitSSOLoginResponse) ProtoMessage() {}
func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[4]
+ mi := &file_daemon_proto_msgTypes[6]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -693,7 +879,7 @@ func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use WaitSSOLoginResponse.ProtoReflect.Descriptor instead.
func (*WaitSSOLoginResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{4}
+ return file_daemon_proto_rawDescGZIP(), []int{6}
}
func (x *WaitSSOLoginResponse) GetEmail() string {
@@ -707,13 +893,14 @@ type UpRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
+ AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *UpRequest) Reset() {
*x = UpRequest{}
- mi := &file_daemon_proto_msgTypes[5]
+ mi := &file_daemon_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -725,7 +912,7 @@ func (x *UpRequest) String() string {
func (*UpRequest) ProtoMessage() {}
func (x *UpRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[5]
+ mi := &file_daemon_proto_msgTypes[7]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -738,7 +925,7 @@ func (x *UpRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use UpRequest.ProtoReflect.Descriptor instead.
func (*UpRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{5}
+ return file_daemon_proto_rawDescGZIP(), []int{7}
}
func (x *UpRequest) GetProfileName() string {
@@ -755,6 +942,13 @@ func (x *UpRequest) GetUsername() string {
return ""
}
+func (x *UpRequest) GetAutoUpdate() bool {
+ if x != nil && x.AutoUpdate != nil {
+ return *x.AutoUpdate
+ }
+ return false
+}
+
type UpResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -763,7 +957,7 @@ type UpResponse struct {
func (x *UpResponse) Reset() {
*x = UpResponse{}
- mi := &file_daemon_proto_msgTypes[6]
+ mi := &file_daemon_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -775,7 +969,7 @@ func (x *UpResponse) String() string {
func (*UpResponse) ProtoMessage() {}
func (x *UpResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[6]
+ mi := &file_daemon_proto_msgTypes[8]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -788,7 +982,7 @@ func (x *UpResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use UpResponse.ProtoReflect.Descriptor instead.
func (*UpResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{6}
+ return file_daemon_proto_rawDescGZIP(), []int{8}
}
type StatusRequest struct {
@@ -803,7 +997,7 @@ type StatusRequest struct {
func (x *StatusRequest) Reset() {
*x = StatusRequest{}
- mi := &file_daemon_proto_msgTypes[7]
+ mi := &file_daemon_proto_msgTypes[9]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -815,7 +1009,7 @@ func (x *StatusRequest) String() string {
func (*StatusRequest) ProtoMessage() {}
func (x *StatusRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[7]
+ mi := &file_daemon_proto_msgTypes[9]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -828,7 +1022,7 @@ func (x *StatusRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use StatusRequest.ProtoReflect.Descriptor instead.
func (*StatusRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{7}
+ return file_daemon_proto_rawDescGZIP(), []int{9}
}
func (x *StatusRequest) GetGetFullPeerStatus() bool {
@@ -865,7 +1059,7 @@ type StatusResponse struct {
func (x *StatusResponse) Reset() {
*x = StatusResponse{}
- mi := &file_daemon_proto_msgTypes[8]
+ mi := &file_daemon_proto_msgTypes[10]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -877,7 +1071,7 @@ func (x *StatusResponse) String() string {
func (*StatusResponse) ProtoMessage() {}
func (x *StatusResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[8]
+ mi := &file_daemon_proto_msgTypes[10]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -890,7 +1084,7 @@ func (x *StatusResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead.
func (*StatusResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{8}
+ return file_daemon_proto_rawDescGZIP(), []int{10}
}
func (x *StatusResponse) GetStatus() string {
@@ -922,7 +1116,7 @@ type DownRequest struct {
func (x *DownRequest) Reset() {
*x = DownRequest{}
- mi := &file_daemon_proto_msgTypes[9]
+ mi := &file_daemon_proto_msgTypes[11]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -934,7 +1128,7 @@ func (x *DownRequest) String() string {
func (*DownRequest) ProtoMessage() {}
func (x *DownRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[9]
+ mi := &file_daemon_proto_msgTypes[11]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -947,7 +1141,7 @@ func (x *DownRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use DownRequest.ProtoReflect.Descriptor instead.
func (*DownRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{9}
+ return file_daemon_proto_rawDescGZIP(), []int{11}
}
type DownResponse struct {
@@ -958,7 +1152,7 @@ type DownResponse struct {
func (x *DownResponse) Reset() {
*x = DownResponse{}
- mi := &file_daemon_proto_msgTypes[10]
+ mi := &file_daemon_proto_msgTypes[12]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -970,7 +1164,7 @@ func (x *DownResponse) String() string {
func (*DownResponse) ProtoMessage() {}
func (x *DownResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[10]
+ mi := &file_daemon_proto_msgTypes[12]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -983,7 +1177,7 @@ func (x *DownResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use DownResponse.ProtoReflect.Descriptor instead.
func (*DownResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{10}
+ return file_daemon_proto_rawDescGZIP(), []int{12}
}
type GetConfigRequest struct {
@@ -996,7 +1190,7 @@ type GetConfigRequest struct {
func (x *GetConfigRequest) Reset() {
*x = GetConfigRequest{}
- mi := &file_daemon_proto_msgTypes[11]
+ mi := &file_daemon_proto_msgTypes[13]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1008,7 +1202,7 @@ func (x *GetConfigRequest) String() string {
func (*GetConfigRequest) ProtoMessage() {}
func (x *GetConfigRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[11]
+ mi := &file_daemon_proto_msgTypes[13]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1021,7 +1215,7 @@ func (x *GetConfigRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetConfigRequest.ProtoReflect.Descriptor instead.
func (*GetConfigRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{11}
+ return file_daemon_proto_rawDescGZIP(), []int{13}
}
func (x *GetConfigRequest) GetProfileName() string {
@@ -1049,30 +1243,36 @@ type GetConfigResponse struct {
// preSharedKey settings value.
PreSharedKey string `protobuf:"bytes,4,opt,name=preSharedKey,proto3" json:"preSharedKey,omitempty"`
// adminURL settings value.
- AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"`
- InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"`
- WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"`
- Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"`
- DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"`
- ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"`
- RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
- RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"`
- DisableNotifications bool `protobuf:"varint,13,opt,name=disable_notifications,json=disableNotifications,proto3" json:"disable_notifications,omitempty"`
- LazyConnectionEnabled bool `protobuf:"varint,14,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"`
- BlockInbound bool `protobuf:"varint,15,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"`
- NetworkMonitor bool `protobuf:"varint,16,opt,name=networkMonitor,proto3" json:"networkMonitor,omitempty"`
- DisableDns bool `protobuf:"varint,17,opt,name=disable_dns,json=disableDns,proto3" json:"disable_dns,omitempty"`
- DisableClientRoutes bool `protobuf:"varint,18,opt,name=disable_client_routes,json=disableClientRoutes,proto3" json:"disable_client_routes,omitempty"`
- DisableServerRoutes bool `protobuf:"varint,19,opt,name=disable_server_routes,json=disableServerRoutes,proto3" json:"disable_server_routes,omitempty"`
- BlockLanAccess bool `protobuf:"varint,20,opt,name=block_lan_access,json=blockLanAccess,proto3" json:"block_lan_access,omitempty"`
- DisableFirewall bool `protobuf:"varint,21,opt,name=disable_firewall,json=disableFirewall,proto3" json:"disable_firewall,omitempty"`
- unknownFields protoimpl.UnknownFields
- sizeCache protoimpl.SizeCache
+ AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"`
+ InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"`
+ WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"`
+ Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"`
+ DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"`
+ ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"`
+ RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
+ RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"`
+ DisableNotifications bool `protobuf:"varint,13,opt,name=disable_notifications,json=disableNotifications,proto3" json:"disable_notifications,omitempty"`
+ LazyConnectionEnabled bool `protobuf:"varint,14,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"`
+ BlockInbound bool `protobuf:"varint,15,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"`
+ NetworkMonitor bool `protobuf:"varint,16,opt,name=networkMonitor,proto3" json:"networkMonitor,omitempty"`
+ DisableDns bool `protobuf:"varint,17,opt,name=disable_dns,json=disableDns,proto3" json:"disable_dns,omitempty"`
+ DisableClientRoutes bool `protobuf:"varint,18,opt,name=disable_client_routes,json=disableClientRoutes,proto3" json:"disable_client_routes,omitempty"`
+ DisableServerRoutes bool `protobuf:"varint,19,opt,name=disable_server_routes,json=disableServerRoutes,proto3" json:"disable_server_routes,omitempty"`
+ BlockLanAccess bool `protobuf:"varint,20,opt,name=block_lan_access,json=blockLanAccess,proto3" json:"block_lan_access,omitempty"`
+ EnableSSHRoot bool `protobuf:"varint,21,opt,name=enableSSHRoot,proto3" json:"enableSSHRoot,omitempty"`
+ EnableSSHSFTP bool `protobuf:"varint,24,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"`
+ EnableSSHLocalPortForwarding bool `protobuf:"varint,22,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"`
+ EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
+ DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
+ SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
+ DisableFirewall bool `protobuf:"varint,27,opt,name=disable_firewall,json=disableFirewall,proto3" json:"disable_firewall,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
}
func (x *GetConfigResponse) Reset() {
*x = GetConfigResponse{}
- mi := &file_daemon_proto_msgTypes[12]
+ mi := &file_daemon_proto_msgTypes[14]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1084,7 +1284,7 @@ func (x *GetConfigResponse) String() string {
func (*GetConfigResponse) ProtoMessage() {}
func (x *GetConfigResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[12]
+ mi := &file_daemon_proto_msgTypes[14]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1097,7 +1297,7 @@ func (x *GetConfigResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetConfigResponse.ProtoReflect.Descriptor instead.
func (*GetConfigResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{12}
+ return file_daemon_proto_rawDescGZIP(), []int{14}
}
func (x *GetConfigResponse) GetManagementUrl() string {
@@ -1240,6 +1440,48 @@ func (x *GetConfigResponse) GetBlockLanAccess() bool {
return false
}
+func (x *GetConfigResponse) GetEnableSSHRoot() bool {
+ if x != nil {
+ return x.EnableSSHRoot
+ }
+ return false
+}
+
+func (x *GetConfigResponse) GetEnableSSHSFTP() bool {
+ if x != nil {
+ return x.EnableSSHSFTP
+ }
+ return false
+}
+
+func (x *GetConfigResponse) GetEnableSSHLocalPortForwarding() bool {
+ if x != nil {
+ return x.EnableSSHLocalPortForwarding
+ }
+ return false
+}
+
+func (x *GetConfigResponse) GetEnableSSHRemotePortForwarding() bool {
+ if x != nil {
+ return x.EnableSSHRemotePortForwarding
+ }
+ return false
+}
+
+func (x *GetConfigResponse) GetDisableSSHAuth() bool {
+ if x != nil {
+ return x.DisableSSHAuth
+ }
+ return false
+}
+
+func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 {
+ if x != nil {
+ return x.SshJWTCacheTTL
+ }
+ return 0
+}
+
func (x *GetConfigResponse) GetDisableFirewall() bool {
if x != nil {
return x.DisableFirewall
@@ -1267,13 +1509,14 @@ type PeerState struct {
Networks []string `protobuf:"bytes,16,rep,name=networks,proto3" json:"networks,omitempty"`
Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"`
+ SshHostKey []byte `protobuf:"bytes,19,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *PeerState) Reset() {
*x = PeerState{}
- mi := &file_daemon_proto_msgTypes[13]
+ mi := &file_daemon_proto_msgTypes[15]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1285,7 +1528,7 @@ func (x *PeerState) String() string {
func (*PeerState) ProtoMessage() {}
func (x *PeerState) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[13]
+ mi := &file_daemon_proto_msgTypes[15]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1298,7 +1541,7 @@ func (x *PeerState) ProtoReflect() protoreflect.Message {
// Deprecated: Use PeerState.ProtoReflect.Descriptor instead.
func (*PeerState) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{13}
+ return file_daemon_proto_rawDescGZIP(), []int{15}
}
func (x *PeerState) GetIP() string {
@@ -1420,6 +1663,13 @@ func (x *PeerState) GetRelayAddress() string {
return ""
}
+func (x *PeerState) GetSshHostKey() []byte {
+ if x != nil {
+ return x.SshHostKey
+ }
+ return nil
+}
+
// LocalPeerState contains the latest state of the local peer
type LocalPeerState struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -1436,7 +1686,7 @@ type LocalPeerState struct {
func (x *LocalPeerState) Reset() {
*x = LocalPeerState{}
- mi := &file_daemon_proto_msgTypes[14]
+ mi := &file_daemon_proto_msgTypes[16]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1448,7 +1698,7 @@ func (x *LocalPeerState) String() string {
func (*LocalPeerState) ProtoMessage() {}
func (x *LocalPeerState) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[14]
+ mi := &file_daemon_proto_msgTypes[16]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1461,7 +1711,7 @@ func (x *LocalPeerState) ProtoReflect() protoreflect.Message {
// Deprecated: Use LocalPeerState.ProtoReflect.Descriptor instead.
func (*LocalPeerState) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{14}
+ return file_daemon_proto_rawDescGZIP(), []int{16}
}
func (x *LocalPeerState) GetIP() string {
@@ -1525,7 +1775,7 @@ type SignalState struct {
func (x *SignalState) Reset() {
*x = SignalState{}
- mi := &file_daemon_proto_msgTypes[15]
+ mi := &file_daemon_proto_msgTypes[17]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1537,7 +1787,7 @@ func (x *SignalState) String() string {
func (*SignalState) ProtoMessage() {}
func (x *SignalState) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[15]
+ mi := &file_daemon_proto_msgTypes[17]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1550,7 +1800,7 @@ func (x *SignalState) ProtoReflect() protoreflect.Message {
// Deprecated: Use SignalState.ProtoReflect.Descriptor instead.
func (*SignalState) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{15}
+ return file_daemon_proto_rawDescGZIP(), []int{17}
}
func (x *SignalState) GetURL() string {
@@ -1586,7 +1836,7 @@ type ManagementState struct {
func (x *ManagementState) Reset() {
*x = ManagementState{}
- mi := &file_daemon_proto_msgTypes[16]
+ mi := &file_daemon_proto_msgTypes[18]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1598,7 +1848,7 @@ func (x *ManagementState) String() string {
func (*ManagementState) ProtoMessage() {}
func (x *ManagementState) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[16]
+ mi := &file_daemon_proto_msgTypes[18]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1611,7 +1861,7 @@ func (x *ManagementState) ProtoReflect() protoreflect.Message {
// Deprecated: Use ManagementState.ProtoReflect.Descriptor instead.
func (*ManagementState) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{16}
+ return file_daemon_proto_rawDescGZIP(), []int{18}
}
func (x *ManagementState) GetURL() string {
@@ -1647,7 +1897,7 @@ type RelayState struct {
func (x *RelayState) Reset() {
*x = RelayState{}
- mi := &file_daemon_proto_msgTypes[17]
+ mi := &file_daemon_proto_msgTypes[19]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1659,7 +1909,7 @@ func (x *RelayState) String() string {
func (*RelayState) ProtoMessage() {}
func (x *RelayState) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[17]
+ mi := &file_daemon_proto_msgTypes[19]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1672,7 +1922,7 @@ func (x *RelayState) ProtoReflect() protoreflect.Message {
// Deprecated: Use RelayState.ProtoReflect.Descriptor instead.
func (*RelayState) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{17}
+ return file_daemon_proto_rawDescGZIP(), []int{19}
}
func (x *RelayState) GetURI() string {
@@ -1708,7 +1958,7 @@ type NSGroupState struct {
func (x *NSGroupState) Reset() {
*x = NSGroupState{}
- mi := &file_daemon_proto_msgTypes[18]
+ mi := &file_daemon_proto_msgTypes[20]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1720,7 +1970,7 @@ func (x *NSGroupState) String() string {
func (*NSGroupState) ProtoMessage() {}
func (x *NSGroupState) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[18]
+ mi := &file_daemon_proto_msgTypes[20]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1733,7 +1983,7 @@ func (x *NSGroupState) ProtoReflect() protoreflect.Message {
// Deprecated: Use NSGroupState.ProtoReflect.Descriptor instead.
func (*NSGroupState) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{18}
+ return file_daemon_proto_rawDescGZIP(), []int{20}
}
func (x *NSGroupState) GetServers() []string {
@@ -1764,6 +2014,128 @@ func (x *NSGroupState) GetError() string {
return ""
}
+// SSHSessionInfo contains information about an active SSH session
+type SSHSessionInfo struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
+ RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"`
+ Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"`
+ JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *SSHSessionInfo) Reset() {
+ *x = SSHSessionInfo{}
+ mi := &file_daemon_proto_msgTypes[21]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *SSHSessionInfo) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*SSHSessionInfo) ProtoMessage() {}
+
+func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[21]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use SSHSessionInfo.ProtoReflect.Descriptor instead.
+func (*SSHSessionInfo) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{21}
+}
+
+func (x *SSHSessionInfo) GetUsername() string {
+ if x != nil {
+ return x.Username
+ }
+ return ""
+}
+
+func (x *SSHSessionInfo) GetRemoteAddress() string {
+ if x != nil {
+ return x.RemoteAddress
+ }
+ return ""
+}
+
+func (x *SSHSessionInfo) GetCommand() string {
+ if x != nil {
+ return x.Command
+ }
+ return ""
+}
+
+func (x *SSHSessionInfo) GetJwtUsername() string {
+ if x != nil {
+ return x.JwtUsername
+ }
+ return ""
+}
+
+// SSHServerState contains the latest state of the SSH server
+type SSHServerState struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"`
+ Sessions []*SSHSessionInfo `protobuf:"bytes,2,rep,name=sessions,proto3" json:"sessions,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *SSHServerState) Reset() {
+ *x = SSHServerState{}
+ mi := &file_daemon_proto_msgTypes[22]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *SSHServerState) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*SSHServerState) ProtoMessage() {}
+
+func (x *SSHServerState) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[22]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use SSHServerState.ProtoReflect.Descriptor instead.
+func (*SSHServerState) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{22}
+}
+
+func (x *SSHServerState) GetEnabled() bool {
+ if x != nil {
+ return x.Enabled
+ }
+ return false
+}
+
+func (x *SSHServerState) GetSessions() []*SSHSessionInfo {
+ if x != nil {
+ return x.Sessions
+ }
+ return nil
+}
+
// FullStatus contains the full state held by the Status instance
type FullStatus struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -1776,13 +2148,14 @@ type FullStatus struct {
NumberOfForwardingRules int32 `protobuf:"varint,8,opt,name=NumberOfForwardingRules,proto3" json:"NumberOfForwardingRules,omitempty"`
Events []*SystemEvent `protobuf:"bytes,7,rep,name=events,proto3" json:"events,omitempty"`
LazyConnectionEnabled bool `protobuf:"varint,9,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"`
+ SshServerState *SSHServerState `protobuf:"bytes,10,opt,name=sshServerState,proto3" json:"sshServerState,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *FullStatus) Reset() {
*x = FullStatus{}
- mi := &file_daemon_proto_msgTypes[19]
+ mi := &file_daemon_proto_msgTypes[23]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1794,7 +2167,7 @@ func (x *FullStatus) String() string {
func (*FullStatus) ProtoMessage() {}
func (x *FullStatus) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[19]
+ mi := &file_daemon_proto_msgTypes[23]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1807,7 +2180,7 @@ func (x *FullStatus) ProtoReflect() protoreflect.Message {
// Deprecated: Use FullStatus.ProtoReflect.Descriptor instead.
func (*FullStatus) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{19}
+ return file_daemon_proto_rawDescGZIP(), []int{23}
}
func (x *FullStatus) GetManagementState() *ManagementState {
@@ -1873,6 +2246,13 @@ func (x *FullStatus) GetLazyConnectionEnabled() bool {
return false
}
+func (x *FullStatus) GetSshServerState() *SSHServerState {
+ if x != nil {
+ return x.SshServerState
+ }
+ return nil
+}
+
// Networks
type ListNetworksRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -1882,7 +2262,7 @@ type ListNetworksRequest struct {
func (x *ListNetworksRequest) Reset() {
*x = ListNetworksRequest{}
- mi := &file_daemon_proto_msgTypes[20]
+ mi := &file_daemon_proto_msgTypes[24]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1894,7 +2274,7 @@ func (x *ListNetworksRequest) String() string {
func (*ListNetworksRequest) ProtoMessage() {}
func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[20]
+ mi := &file_daemon_proto_msgTypes[24]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1907,7 +2287,7 @@ func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead.
func (*ListNetworksRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{20}
+ return file_daemon_proto_rawDescGZIP(), []int{24}
}
type ListNetworksResponse struct {
@@ -1919,7 +2299,7 @@ type ListNetworksResponse struct {
func (x *ListNetworksResponse) Reset() {
*x = ListNetworksResponse{}
- mi := &file_daemon_proto_msgTypes[21]
+ mi := &file_daemon_proto_msgTypes[25]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1931,7 +2311,7 @@ func (x *ListNetworksResponse) String() string {
func (*ListNetworksResponse) ProtoMessage() {}
func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[21]
+ mi := &file_daemon_proto_msgTypes[25]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1944,7 +2324,7 @@ func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead.
func (*ListNetworksResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{21}
+ return file_daemon_proto_rawDescGZIP(), []int{25}
}
func (x *ListNetworksResponse) GetRoutes() []*Network {
@@ -1965,7 +2345,7 @@ type SelectNetworksRequest struct {
func (x *SelectNetworksRequest) Reset() {
*x = SelectNetworksRequest{}
- mi := &file_daemon_proto_msgTypes[22]
+ mi := &file_daemon_proto_msgTypes[26]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1977,7 +2357,7 @@ func (x *SelectNetworksRequest) String() string {
func (*SelectNetworksRequest) ProtoMessage() {}
func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[22]
+ mi := &file_daemon_proto_msgTypes[26]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1990,7 +2370,7 @@ func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead.
func (*SelectNetworksRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{22}
+ return file_daemon_proto_rawDescGZIP(), []int{26}
}
func (x *SelectNetworksRequest) GetNetworkIDs() []string {
@@ -2022,7 +2402,7 @@ type SelectNetworksResponse struct {
func (x *SelectNetworksResponse) Reset() {
*x = SelectNetworksResponse{}
- mi := &file_daemon_proto_msgTypes[23]
+ mi := &file_daemon_proto_msgTypes[27]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2034,7 +2414,7 @@ func (x *SelectNetworksResponse) String() string {
func (*SelectNetworksResponse) ProtoMessage() {}
func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[23]
+ mi := &file_daemon_proto_msgTypes[27]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2047,7 +2427,7 @@ func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead.
func (*SelectNetworksResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{23}
+ return file_daemon_proto_rawDescGZIP(), []int{27}
}
type IPList struct {
@@ -2059,7 +2439,7 @@ type IPList struct {
func (x *IPList) Reset() {
*x = IPList{}
- mi := &file_daemon_proto_msgTypes[24]
+ mi := &file_daemon_proto_msgTypes[28]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2071,7 +2451,7 @@ func (x *IPList) String() string {
func (*IPList) ProtoMessage() {}
func (x *IPList) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[24]
+ mi := &file_daemon_proto_msgTypes[28]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2084,7 +2464,7 @@ func (x *IPList) ProtoReflect() protoreflect.Message {
// Deprecated: Use IPList.ProtoReflect.Descriptor instead.
func (*IPList) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{24}
+ return file_daemon_proto_rawDescGZIP(), []int{28}
}
func (x *IPList) GetIps() []string {
@@ -2107,7 +2487,7 @@ type Network struct {
func (x *Network) Reset() {
*x = Network{}
- mi := &file_daemon_proto_msgTypes[25]
+ mi := &file_daemon_proto_msgTypes[29]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2119,7 +2499,7 @@ func (x *Network) String() string {
func (*Network) ProtoMessage() {}
func (x *Network) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[25]
+ mi := &file_daemon_proto_msgTypes[29]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2132,7 +2512,7 @@ func (x *Network) ProtoReflect() protoreflect.Message {
// Deprecated: Use Network.ProtoReflect.Descriptor instead.
func (*Network) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{25}
+ return file_daemon_proto_rawDescGZIP(), []int{29}
}
func (x *Network) GetID() string {
@@ -2184,7 +2564,7 @@ type PortInfo struct {
func (x *PortInfo) Reset() {
*x = PortInfo{}
- mi := &file_daemon_proto_msgTypes[26]
+ mi := &file_daemon_proto_msgTypes[30]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2196,7 +2576,7 @@ func (x *PortInfo) String() string {
func (*PortInfo) ProtoMessage() {}
func (x *PortInfo) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[26]
+ mi := &file_daemon_proto_msgTypes[30]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2209,7 +2589,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message {
// Deprecated: Use PortInfo.ProtoReflect.Descriptor instead.
func (*PortInfo) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{26}
+ return file_daemon_proto_rawDescGZIP(), []int{30}
}
func (x *PortInfo) GetPortSelection() isPortInfo_PortSelection {
@@ -2266,7 +2646,7 @@ type ForwardingRule struct {
func (x *ForwardingRule) Reset() {
*x = ForwardingRule{}
- mi := &file_daemon_proto_msgTypes[27]
+ mi := &file_daemon_proto_msgTypes[31]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2278,7 +2658,7 @@ func (x *ForwardingRule) String() string {
func (*ForwardingRule) ProtoMessage() {}
func (x *ForwardingRule) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[27]
+ mi := &file_daemon_proto_msgTypes[31]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2291,7 +2671,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message {
// Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead.
func (*ForwardingRule) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{27}
+ return file_daemon_proto_rawDescGZIP(), []int{31}
}
func (x *ForwardingRule) GetProtocol() string {
@@ -2338,7 +2718,7 @@ type ForwardingRulesResponse struct {
func (x *ForwardingRulesResponse) Reset() {
*x = ForwardingRulesResponse{}
- mi := &file_daemon_proto_msgTypes[28]
+ mi := &file_daemon_proto_msgTypes[32]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2350,7 +2730,7 @@ func (x *ForwardingRulesResponse) String() string {
func (*ForwardingRulesResponse) ProtoMessage() {}
func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[28]
+ mi := &file_daemon_proto_msgTypes[32]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2363,7 +2743,7 @@ func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use ForwardingRulesResponse.ProtoReflect.Descriptor instead.
func (*ForwardingRulesResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{28}
+ return file_daemon_proto_rawDescGZIP(), []int{32}
}
func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule {
@@ -2387,7 +2767,7 @@ type DebugBundleRequest struct {
func (x *DebugBundleRequest) Reset() {
*x = DebugBundleRequest{}
- mi := &file_daemon_proto_msgTypes[29]
+ mi := &file_daemon_proto_msgTypes[33]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2399,7 +2779,7 @@ func (x *DebugBundleRequest) String() string {
func (*DebugBundleRequest) ProtoMessage() {}
func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[29]
+ mi := &file_daemon_proto_msgTypes[33]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2412,7 +2792,7 @@ func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead.
func (*DebugBundleRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{29}
+ return file_daemon_proto_rawDescGZIP(), []int{33}
}
func (x *DebugBundleRequest) GetAnonymize() bool {
@@ -2461,7 +2841,7 @@ type DebugBundleResponse struct {
func (x *DebugBundleResponse) Reset() {
*x = DebugBundleResponse{}
- mi := &file_daemon_proto_msgTypes[30]
+ mi := &file_daemon_proto_msgTypes[34]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2473,7 +2853,7 @@ func (x *DebugBundleResponse) String() string {
func (*DebugBundleResponse) ProtoMessage() {}
func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[30]
+ mi := &file_daemon_proto_msgTypes[34]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2486,7 +2866,7 @@ func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead.
func (*DebugBundleResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{30}
+ return file_daemon_proto_rawDescGZIP(), []int{34}
}
func (x *DebugBundleResponse) GetPath() string {
@@ -2518,7 +2898,7 @@ type GetLogLevelRequest struct {
func (x *GetLogLevelRequest) Reset() {
*x = GetLogLevelRequest{}
- mi := &file_daemon_proto_msgTypes[31]
+ mi := &file_daemon_proto_msgTypes[35]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2530,7 +2910,7 @@ func (x *GetLogLevelRequest) String() string {
func (*GetLogLevelRequest) ProtoMessage() {}
func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[31]
+ mi := &file_daemon_proto_msgTypes[35]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2543,7 +2923,7 @@ func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead.
func (*GetLogLevelRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{31}
+ return file_daemon_proto_rawDescGZIP(), []int{35}
}
type GetLogLevelResponse struct {
@@ -2555,7 +2935,7 @@ type GetLogLevelResponse struct {
func (x *GetLogLevelResponse) Reset() {
*x = GetLogLevelResponse{}
- mi := &file_daemon_proto_msgTypes[32]
+ mi := &file_daemon_proto_msgTypes[36]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2567,7 +2947,7 @@ func (x *GetLogLevelResponse) String() string {
func (*GetLogLevelResponse) ProtoMessage() {}
func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[32]
+ mi := &file_daemon_proto_msgTypes[36]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2580,7 +2960,7 @@ func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead.
func (*GetLogLevelResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{32}
+ return file_daemon_proto_rawDescGZIP(), []int{36}
}
func (x *GetLogLevelResponse) GetLevel() LogLevel {
@@ -2599,7 +2979,7 @@ type SetLogLevelRequest struct {
func (x *SetLogLevelRequest) Reset() {
*x = SetLogLevelRequest{}
- mi := &file_daemon_proto_msgTypes[33]
+ mi := &file_daemon_proto_msgTypes[37]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2611,7 +2991,7 @@ func (x *SetLogLevelRequest) String() string {
func (*SetLogLevelRequest) ProtoMessage() {}
func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[33]
+ mi := &file_daemon_proto_msgTypes[37]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2624,7 +3004,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead.
func (*SetLogLevelRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{33}
+ return file_daemon_proto_rawDescGZIP(), []int{37}
}
func (x *SetLogLevelRequest) GetLevel() LogLevel {
@@ -2642,7 +3022,7 @@ type SetLogLevelResponse struct {
func (x *SetLogLevelResponse) Reset() {
*x = SetLogLevelResponse{}
- mi := &file_daemon_proto_msgTypes[34]
+ mi := &file_daemon_proto_msgTypes[38]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2654,7 +3034,7 @@ func (x *SetLogLevelResponse) String() string {
func (*SetLogLevelResponse) ProtoMessage() {}
func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[34]
+ mi := &file_daemon_proto_msgTypes[38]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2667,7 +3047,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead.
func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{34}
+ return file_daemon_proto_rawDescGZIP(), []int{38}
}
// State represents a daemon state entry
@@ -2680,7 +3060,7 @@ type State struct {
func (x *State) Reset() {
*x = State{}
- mi := &file_daemon_proto_msgTypes[35]
+ mi := &file_daemon_proto_msgTypes[39]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2692,7 +3072,7 @@ func (x *State) String() string {
func (*State) ProtoMessage() {}
func (x *State) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[35]
+ mi := &file_daemon_proto_msgTypes[39]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2705,7 +3085,7 @@ func (x *State) ProtoReflect() protoreflect.Message {
// Deprecated: Use State.ProtoReflect.Descriptor instead.
func (*State) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{35}
+ return file_daemon_proto_rawDescGZIP(), []int{39}
}
func (x *State) GetName() string {
@@ -2724,7 +3104,7 @@ type ListStatesRequest struct {
func (x *ListStatesRequest) Reset() {
*x = ListStatesRequest{}
- mi := &file_daemon_proto_msgTypes[36]
+ mi := &file_daemon_proto_msgTypes[40]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2736,7 +3116,7 @@ func (x *ListStatesRequest) String() string {
func (*ListStatesRequest) ProtoMessage() {}
func (x *ListStatesRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[36]
+ mi := &file_daemon_proto_msgTypes[40]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2749,7 +3129,7 @@ func (x *ListStatesRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead.
func (*ListStatesRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{36}
+ return file_daemon_proto_rawDescGZIP(), []int{40}
}
// ListStatesResponse contains a list of states
@@ -2762,7 +3142,7 @@ type ListStatesResponse struct {
func (x *ListStatesResponse) Reset() {
*x = ListStatesResponse{}
- mi := &file_daemon_proto_msgTypes[37]
+ mi := &file_daemon_proto_msgTypes[41]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2774,7 +3154,7 @@ func (x *ListStatesResponse) String() string {
func (*ListStatesResponse) ProtoMessage() {}
func (x *ListStatesResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[37]
+ mi := &file_daemon_proto_msgTypes[41]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2787,7 +3167,7 @@ func (x *ListStatesResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead.
func (*ListStatesResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{37}
+ return file_daemon_proto_rawDescGZIP(), []int{41}
}
func (x *ListStatesResponse) GetStates() []*State {
@@ -2808,7 +3188,7 @@ type CleanStateRequest struct {
func (x *CleanStateRequest) Reset() {
*x = CleanStateRequest{}
- mi := &file_daemon_proto_msgTypes[38]
+ mi := &file_daemon_proto_msgTypes[42]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2820,7 +3200,7 @@ func (x *CleanStateRequest) String() string {
func (*CleanStateRequest) ProtoMessage() {}
func (x *CleanStateRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[38]
+ mi := &file_daemon_proto_msgTypes[42]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2833,7 +3213,7 @@ func (x *CleanStateRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead.
func (*CleanStateRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{38}
+ return file_daemon_proto_rawDescGZIP(), []int{42}
}
func (x *CleanStateRequest) GetStateName() string {
@@ -2860,7 +3240,7 @@ type CleanStateResponse struct {
func (x *CleanStateResponse) Reset() {
*x = CleanStateResponse{}
- mi := &file_daemon_proto_msgTypes[39]
+ mi := &file_daemon_proto_msgTypes[43]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2872,7 +3252,7 @@ func (x *CleanStateResponse) String() string {
func (*CleanStateResponse) ProtoMessage() {}
func (x *CleanStateResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[39]
+ mi := &file_daemon_proto_msgTypes[43]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2885,7 +3265,7 @@ func (x *CleanStateResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead.
func (*CleanStateResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{39}
+ return file_daemon_proto_rawDescGZIP(), []int{43}
}
func (x *CleanStateResponse) GetCleanedStates() int32 {
@@ -2906,7 +3286,7 @@ type DeleteStateRequest struct {
func (x *DeleteStateRequest) Reset() {
*x = DeleteStateRequest{}
- mi := &file_daemon_proto_msgTypes[40]
+ mi := &file_daemon_proto_msgTypes[44]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2918,7 +3298,7 @@ func (x *DeleteStateRequest) String() string {
func (*DeleteStateRequest) ProtoMessage() {}
func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[40]
+ mi := &file_daemon_proto_msgTypes[44]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2931,7 +3311,7 @@ func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead.
func (*DeleteStateRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{40}
+ return file_daemon_proto_rawDescGZIP(), []int{44}
}
func (x *DeleteStateRequest) GetStateName() string {
@@ -2958,7 +3338,7 @@ type DeleteStateResponse struct {
func (x *DeleteStateResponse) Reset() {
*x = DeleteStateResponse{}
- mi := &file_daemon_proto_msgTypes[41]
+ mi := &file_daemon_proto_msgTypes[45]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2970,7 +3350,7 @@ func (x *DeleteStateResponse) String() string {
func (*DeleteStateResponse) ProtoMessage() {}
func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[41]
+ mi := &file_daemon_proto_msgTypes[45]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2983,7 +3363,7 @@ func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead.
func (*DeleteStateResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{41}
+ return file_daemon_proto_rawDescGZIP(), []int{45}
}
func (x *DeleteStateResponse) GetDeletedStates() int32 {
@@ -3002,7 +3382,7 @@ type SetSyncResponsePersistenceRequest struct {
func (x *SetSyncResponsePersistenceRequest) Reset() {
*x = SetSyncResponsePersistenceRequest{}
- mi := &file_daemon_proto_msgTypes[42]
+ mi := &file_daemon_proto_msgTypes[46]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3014,7 +3394,7 @@ func (x *SetSyncResponsePersistenceRequest) String() string {
func (*SetSyncResponsePersistenceRequest) ProtoMessage() {}
func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[42]
+ mi := &file_daemon_proto_msgTypes[46]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3027,7 +3407,7 @@ func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message
// Deprecated: Use SetSyncResponsePersistenceRequest.ProtoReflect.Descriptor instead.
func (*SetSyncResponsePersistenceRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{42}
+ return file_daemon_proto_rawDescGZIP(), []int{46}
}
func (x *SetSyncResponsePersistenceRequest) GetEnabled() bool {
@@ -3045,7 +3425,7 @@ type SetSyncResponsePersistenceResponse struct {
func (x *SetSyncResponsePersistenceResponse) Reset() {
*x = SetSyncResponsePersistenceResponse{}
- mi := &file_daemon_proto_msgTypes[43]
+ mi := &file_daemon_proto_msgTypes[47]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3057,7 +3437,7 @@ func (x *SetSyncResponsePersistenceResponse) String() string {
func (*SetSyncResponsePersistenceResponse) ProtoMessage() {}
func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[43]
+ mi := &file_daemon_proto_msgTypes[47]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3070,7 +3450,7 @@ func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message
// Deprecated: Use SetSyncResponsePersistenceResponse.ProtoReflect.Descriptor instead.
func (*SetSyncResponsePersistenceResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{43}
+ return file_daemon_proto_rawDescGZIP(), []int{47}
}
type TCPFlags struct {
@@ -3087,7 +3467,7 @@ type TCPFlags struct {
func (x *TCPFlags) Reset() {
*x = TCPFlags{}
- mi := &file_daemon_proto_msgTypes[44]
+ mi := &file_daemon_proto_msgTypes[48]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3099,7 +3479,7 @@ func (x *TCPFlags) String() string {
func (*TCPFlags) ProtoMessage() {}
func (x *TCPFlags) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[44]
+ mi := &file_daemon_proto_msgTypes[48]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3112,7 +3492,7 @@ func (x *TCPFlags) ProtoReflect() protoreflect.Message {
// Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead.
func (*TCPFlags) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{44}
+ return file_daemon_proto_rawDescGZIP(), []int{48}
}
func (x *TCPFlags) GetSyn() bool {
@@ -3174,7 +3554,7 @@ type TracePacketRequest struct {
func (x *TracePacketRequest) Reset() {
*x = TracePacketRequest{}
- mi := &file_daemon_proto_msgTypes[45]
+ mi := &file_daemon_proto_msgTypes[49]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3186,7 +3566,7 @@ func (x *TracePacketRequest) String() string {
func (*TracePacketRequest) ProtoMessage() {}
func (x *TracePacketRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[45]
+ mi := &file_daemon_proto_msgTypes[49]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3199,7 +3579,7 @@ func (x *TracePacketRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead.
func (*TracePacketRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{45}
+ return file_daemon_proto_rawDescGZIP(), []int{49}
}
func (x *TracePacketRequest) GetSourceIp() string {
@@ -3277,7 +3657,7 @@ type TraceStage struct {
func (x *TraceStage) Reset() {
*x = TraceStage{}
- mi := &file_daemon_proto_msgTypes[46]
+ mi := &file_daemon_proto_msgTypes[50]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3289,7 +3669,7 @@ func (x *TraceStage) String() string {
func (*TraceStage) ProtoMessage() {}
func (x *TraceStage) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[46]
+ mi := &file_daemon_proto_msgTypes[50]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3302,7 +3682,7 @@ func (x *TraceStage) ProtoReflect() protoreflect.Message {
// Deprecated: Use TraceStage.ProtoReflect.Descriptor instead.
func (*TraceStage) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{46}
+ return file_daemon_proto_rawDescGZIP(), []int{50}
}
func (x *TraceStage) GetName() string {
@@ -3343,7 +3723,7 @@ type TracePacketResponse struct {
func (x *TracePacketResponse) Reset() {
*x = TracePacketResponse{}
- mi := &file_daemon_proto_msgTypes[47]
+ mi := &file_daemon_proto_msgTypes[51]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3355,7 +3735,7 @@ func (x *TracePacketResponse) String() string {
func (*TracePacketResponse) ProtoMessage() {}
func (x *TracePacketResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[47]
+ mi := &file_daemon_proto_msgTypes[51]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3368,7 +3748,7 @@ func (x *TracePacketResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead.
func (*TracePacketResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{47}
+ return file_daemon_proto_rawDescGZIP(), []int{51}
}
func (x *TracePacketResponse) GetStages() []*TraceStage {
@@ -3393,7 +3773,7 @@ type SubscribeRequest struct {
func (x *SubscribeRequest) Reset() {
*x = SubscribeRequest{}
- mi := &file_daemon_proto_msgTypes[48]
+ mi := &file_daemon_proto_msgTypes[52]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3405,7 +3785,7 @@ func (x *SubscribeRequest) String() string {
func (*SubscribeRequest) ProtoMessage() {}
func (x *SubscribeRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[48]
+ mi := &file_daemon_proto_msgTypes[52]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3418,7 +3798,7 @@ func (x *SubscribeRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead.
func (*SubscribeRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{48}
+ return file_daemon_proto_rawDescGZIP(), []int{52}
}
type SystemEvent struct {
@@ -3436,7 +3816,7 @@ type SystemEvent struct {
func (x *SystemEvent) Reset() {
*x = SystemEvent{}
- mi := &file_daemon_proto_msgTypes[49]
+ mi := &file_daemon_proto_msgTypes[53]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3448,7 +3828,7 @@ func (x *SystemEvent) String() string {
func (*SystemEvent) ProtoMessage() {}
func (x *SystemEvent) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[49]
+ mi := &file_daemon_proto_msgTypes[53]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3461,7 +3841,7 @@ func (x *SystemEvent) ProtoReflect() protoreflect.Message {
// Deprecated: Use SystemEvent.ProtoReflect.Descriptor instead.
func (*SystemEvent) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{49}
+ return file_daemon_proto_rawDescGZIP(), []int{53}
}
func (x *SystemEvent) GetId() string {
@@ -3521,7 +3901,7 @@ type GetEventsRequest struct {
func (x *GetEventsRequest) Reset() {
*x = GetEventsRequest{}
- mi := &file_daemon_proto_msgTypes[50]
+ mi := &file_daemon_proto_msgTypes[54]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3533,7 +3913,7 @@ func (x *GetEventsRequest) String() string {
func (*GetEventsRequest) ProtoMessage() {}
func (x *GetEventsRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[50]
+ mi := &file_daemon_proto_msgTypes[54]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3546,7 +3926,7 @@ func (x *GetEventsRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetEventsRequest.ProtoReflect.Descriptor instead.
func (*GetEventsRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{50}
+ return file_daemon_proto_rawDescGZIP(), []int{54}
}
type GetEventsResponse struct {
@@ -3558,7 +3938,7 @@ type GetEventsResponse struct {
func (x *GetEventsResponse) Reset() {
*x = GetEventsResponse{}
- mi := &file_daemon_proto_msgTypes[51]
+ mi := &file_daemon_proto_msgTypes[55]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3570,7 +3950,7 @@ func (x *GetEventsResponse) String() string {
func (*GetEventsResponse) ProtoMessage() {}
func (x *GetEventsResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[51]
+ mi := &file_daemon_proto_msgTypes[55]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3583,7 +3963,7 @@ func (x *GetEventsResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetEventsResponse.ProtoReflect.Descriptor instead.
func (*GetEventsResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{51}
+ return file_daemon_proto_rawDescGZIP(), []int{55}
}
func (x *GetEventsResponse) GetEvents() []*SystemEvent {
@@ -3603,7 +3983,7 @@ type SwitchProfileRequest struct {
func (x *SwitchProfileRequest) Reset() {
*x = SwitchProfileRequest{}
- mi := &file_daemon_proto_msgTypes[52]
+ mi := &file_daemon_proto_msgTypes[56]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3615,7 +3995,7 @@ func (x *SwitchProfileRequest) String() string {
func (*SwitchProfileRequest) ProtoMessage() {}
func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[52]
+ mi := &file_daemon_proto_msgTypes[56]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3628,7 +4008,7 @@ func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use SwitchProfileRequest.ProtoReflect.Descriptor instead.
func (*SwitchProfileRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{52}
+ return file_daemon_proto_rawDescGZIP(), []int{56}
}
func (x *SwitchProfileRequest) GetProfileName() string {
@@ -3653,7 +4033,7 @@ type SwitchProfileResponse struct {
func (x *SwitchProfileResponse) Reset() {
*x = SwitchProfileResponse{}
- mi := &file_daemon_proto_msgTypes[53]
+ mi := &file_daemon_proto_msgTypes[57]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3665,7 +4045,7 @@ func (x *SwitchProfileResponse) String() string {
func (*SwitchProfileResponse) ProtoMessage() {}
func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[53]
+ mi := &file_daemon_proto_msgTypes[57]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3678,7 +4058,7 @@ func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use SwitchProfileResponse.ProtoReflect.Descriptor instead.
func (*SwitchProfileResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{53}
+ return file_daemon_proto_rawDescGZIP(), []int{57}
}
type SetConfigRequest struct {
@@ -3711,16 +4091,22 @@ type SetConfigRequest struct {
ExtraIFaceBlacklist []string `protobuf:"bytes,24,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"`
DnsLabels []string `protobuf:"bytes,25,rep,name=dns_labels,json=dnsLabels,proto3" json:"dns_labels,omitempty"`
// cleanDNSLabels clean map list of DNS labels.
- CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"`
- DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"`
- Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
- unknownFields protoimpl.UnknownFields
- sizeCache protoimpl.SizeCache
+ CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"`
+ DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"`
+ Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
+ EnableSSHRoot *bool `protobuf:"varint,29,opt,name=enableSSHRoot,proto3,oneof" json:"enableSSHRoot,omitempty"`
+ EnableSSHSFTP *bool `protobuf:"varint,30,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
+ EnableSSHLocalPortForwarding *bool `protobuf:"varint,31,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"`
+ EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
+ DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
+ SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
}
func (x *SetConfigRequest) Reset() {
*x = SetConfigRequest{}
- mi := &file_daemon_proto_msgTypes[54]
+ mi := &file_daemon_proto_msgTypes[58]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3732,7 +4118,7 @@ func (x *SetConfigRequest) String() string {
func (*SetConfigRequest) ProtoMessage() {}
func (x *SetConfigRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[54]
+ mi := &file_daemon_proto_msgTypes[58]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3745,7 +4131,7 @@ func (x *SetConfigRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use SetConfigRequest.ProtoReflect.Descriptor instead.
func (*SetConfigRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{54}
+ return file_daemon_proto_rawDescGZIP(), []int{58}
}
func (x *SetConfigRequest) GetUsername() string {
@@ -3944,6 +4330,48 @@ func (x *SetConfigRequest) GetMtu() int64 {
return 0
}
+func (x *SetConfigRequest) GetEnableSSHRoot() bool {
+ if x != nil && x.EnableSSHRoot != nil {
+ return *x.EnableSSHRoot
+ }
+ return false
+}
+
+func (x *SetConfigRequest) GetEnableSSHSFTP() bool {
+ if x != nil && x.EnableSSHSFTP != nil {
+ return *x.EnableSSHSFTP
+ }
+ return false
+}
+
+func (x *SetConfigRequest) GetEnableSSHLocalPortForwarding() bool {
+ if x != nil && x.EnableSSHLocalPortForwarding != nil {
+ return *x.EnableSSHLocalPortForwarding
+ }
+ return false
+}
+
+func (x *SetConfigRequest) GetEnableSSHRemotePortForwarding() bool {
+ if x != nil && x.EnableSSHRemotePortForwarding != nil {
+ return *x.EnableSSHRemotePortForwarding
+ }
+ return false
+}
+
+func (x *SetConfigRequest) GetDisableSSHAuth() bool {
+ if x != nil && x.DisableSSHAuth != nil {
+ return *x.DisableSSHAuth
+ }
+ return false
+}
+
+func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 {
+ if x != nil && x.SshJWTCacheTTL != nil {
+ return *x.SshJWTCacheTTL
+ }
+ return 0
+}
+
type SetConfigResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -3952,7 +4380,7 @@ type SetConfigResponse struct {
func (x *SetConfigResponse) Reset() {
*x = SetConfigResponse{}
- mi := &file_daemon_proto_msgTypes[55]
+ mi := &file_daemon_proto_msgTypes[59]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3964,7 +4392,7 @@ func (x *SetConfigResponse) String() string {
func (*SetConfigResponse) ProtoMessage() {}
func (x *SetConfigResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[55]
+ mi := &file_daemon_proto_msgTypes[59]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3977,7 +4405,7 @@ func (x *SetConfigResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use SetConfigResponse.ProtoReflect.Descriptor instead.
func (*SetConfigResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{55}
+ return file_daemon_proto_rawDescGZIP(), []int{59}
}
type AddProfileRequest struct {
@@ -3990,7 +4418,7 @@ type AddProfileRequest struct {
func (x *AddProfileRequest) Reset() {
*x = AddProfileRequest{}
- mi := &file_daemon_proto_msgTypes[56]
+ mi := &file_daemon_proto_msgTypes[60]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4002,7 +4430,7 @@ func (x *AddProfileRequest) String() string {
func (*AddProfileRequest) ProtoMessage() {}
func (x *AddProfileRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[56]
+ mi := &file_daemon_proto_msgTypes[60]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4015,7 +4443,7 @@ func (x *AddProfileRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use AddProfileRequest.ProtoReflect.Descriptor instead.
func (*AddProfileRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{56}
+ return file_daemon_proto_rawDescGZIP(), []int{60}
}
func (x *AddProfileRequest) GetUsername() string {
@@ -4040,7 +4468,7 @@ type AddProfileResponse struct {
func (x *AddProfileResponse) Reset() {
*x = AddProfileResponse{}
- mi := &file_daemon_proto_msgTypes[57]
+ mi := &file_daemon_proto_msgTypes[61]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4052,7 +4480,7 @@ func (x *AddProfileResponse) String() string {
func (*AddProfileResponse) ProtoMessage() {}
func (x *AddProfileResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[57]
+ mi := &file_daemon_proto_msgTypes[61]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4065,7 +4493,7 @@ func (x *AddProfileResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use AddProfileResponse.ProtoReflect.Descriptor instead.
func (*AddProfileResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{57}
+ return file_daemon_proto_rawDescGZIP(), []int{61}
}
type RemoveProfileRequest struct {
@@ -4078,7 +4506,7 @@ type RemoveProfileRequest struct {
func (x *RemoveProfileRequest) Reset() {
*x = RemoveProfileRequest{}
- mi := &file_daemon_proto_msgTypes[58]
+ mi := &file_daemon_proto_msgTypes[62]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4090,7 +4518,7 @@ func (x *RemoveProfileRequest) String() string {
func (*RemoveProfileRequest) ProtoMessage() {}
func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[58]
+ mi := &file_daemon_proto_msgTypes[62]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4103,7 +4531,7 @@ func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use RemoveProfileRequest.ProtoReflect.Descriptor instead.
func (*RemoveProfileRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{58}
+ return file_daemon_proto_rawDescGZIP(), []int{62}
}
func (x *RemoveProfileRequest) GetUsername() string {
@@ -4128,7 +4556,7 @@ type RemoveProfileResponse struct {
func (x *RemoveProfileResponse) Reset() {
*x = RemoveProfileResponse{}
- mi := &file_daemon_proto_msgTypes[59]
+ mi := &file_daemon_proto_msgTypes[63]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4140,7 +4568,7 @@ func (x *RemoveProfileResponse) String() string {
func (*RemoveProfileResponse) ProtoMessage() {}
func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[59]
+ mi := &file_daemon_proto_msgTypes[63]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4153,7 +4581,7 @@ func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use RemoveProfileResponse.ProtoReflect.Descriptor instead.
func (*RemoveProfileResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{59}
+ return file_daemon_proto_rawDescGZIP(), []int{63}
}
type ListProfilesRequest struct {
@@ -4165,7 +4593,7 @@ type ListProfilesRequest struct {
func (x *ListProfilesRequest) Reset() {
*x = ListProfilesRequest{}
- mi := &file_daemon_proto_msgTypes[60]
+ mi := &file_daemon_proto_msgTypes[64]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4177,7 +4605,7 @@ func (x *ListProfilesRequest) String() string {
func (*ListProfilesRequest) ProtoMessage() {}
func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[60]
+ mi := &file_daemon_proto_msgTypes[64]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4190,7 +4618,7 @@ func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use ListProfilesRequest.ProtoReflect.Descriptor instead.
func (*ListProfilesRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{60}
+ return file_daemon_proto_rawDescGZIP(), []int{64}
}
func (x *ListProfilesRequest) GetUsername() string {
@@ -4209,7 +4637,7 @@ type ListProfilesResponse struct {
func (x *ListProfilesResponse) Reset() {
*x = ListProfilesResponse{}
- mi := &file_daemon_proto_msgTypes[61]
+ mi := &file_daemon_proto_msgTypes[65]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4221,7 +4649,7 @@ func (x *ListProfilesResponse) String() string {
func (*ListProfilesResponse) ProtoMessage() {}
func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[61]
+ mi := &file_daemon_proto_msgTypes[65]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4234,7 +4662,7 @@ func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use ListProfilesResponse.ProtoReflect.Descriptor instead.
func (*ListProfilesResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{61}
+ return file_daemon_proto_rawDescGZIP(), []int{65}
}
func (x *ListProfilesResponse) GetProfiles() []*Profile {
@@ -4254,7 +4682,7 @@ type Profile struct {
func (x *Profile) Reset() {
*x = Profile{}
- mi := &file_daemon_proto_msgTypes[62]
+ mi := &file_daemon_proto_msgTypes[66]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4266,7 +4694,7 @@ func (x *Profile) String() string {
func (*Profile) ProtoMessage() {}
func (x *Profile) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[62]
+ mi := &file_daemon_proto_msgTypes[66]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4279,7 +4707,7 @@ func (x *Profile) ProtoReflect() protoreflect.Message {
// Deprecated: Use Profile.ProtoReflect.Descriptor instead.
func (*Profile) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{62}
+ return file_daemon_proto_rawDescGZIP(), []int{66}
}
func (x *Profile) GetName() string {
@@ -4304,7 +4732,7 @@ type GetActiveProfileRequest struct {
func (x *GetActiveProfileRequest) Reset() {
*x = GetActiveProfileRequest{}
- mi := &file_daemon_proto_msgTypes[63]
+ mi := &file_daemon_proto_msgTypes[67]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4316,7 +4744,7 @@ func (x *GetActiveProfileRequest) String() string {
func (*GetActiveProfileRequest) ProtoMessage() {}
func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[63]
+ mi := &file_daemon_proto_msgTypes[67]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4329,7 +4757,7 @@ func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetActiveProfileRequest.ProtoReflect.Descriptor instead.
func (*GetActiveProfileRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{63}
+ return file_daemon_proto_rawDescGZIP(), []int{67}
}
type GetActiveProfileResponse struct {
@@ -4342,7 +4770,7 @@ type GetActiveProfileResponse struct {
func (x *GetActiveProfileResponse) Reset() {
*x = GetActiveProfileResponse{}
- mi := &file_daemon_proto_msgTypes[64]
+ mi := &file_daemon_proto_msgTypes[68]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4354,7 +4782,7 @@ func (x *GetActiveProfileResponse) String() string {
func (*GetActiveProfileResponse) ProtoMessage() {}
func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[64]
+ mi := &file_daemon_proto_msgTypes[68]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4367,7 +4795,7 @@ func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetActiveProfileResponse.ProtoReflect.Descriptor instead.
func (*GetActiveProfileResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{64}
+ return file_daemon_proto_rawDescGZIP(), []int{68}
}
func (x *GetActiveProfileResponse) GetProfileName() string {
@@ -4394,7 +4822,7 @@ type LogoutRequest struct {
func (x *LogoutRequest) Reset() {
*x = LogoutRequest{}
- mi := &file_daemon_proto_msgTypes[65]
+ mi := &file_daemon_proto_msgTypes[69]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4406,7 +4834,7 @@ func (x *LogoutRequest) String() string {
func (*LogoutRequest) ProtoMessage() {}
func (x *LogoutRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[65]
+ mi := &file_daemon_proto_msgTypes[69]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4419,7 +4847,7 @@ func (x *LogoutRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead.
func (*LogoutRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{65}
+ return file_daemon_proto_rawDescGZIP(), []int{69}
}
func (x *LogoutRequest) GetProfileName() string {
@@ -4444,7 +4872,7 @@ type LogoutResponse struct {
func (x *LogoutResponse) Reset() {
*x = LogoutResponse{}
- mi := &file_daemon_proto_msgTypes[66]
+ mi := &file_daemon_proto_msgTypes[70]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4456,7 +4884,7 @@ func (x *LogoutResponse) String() string {
func (*LogoutResponse) ProtoMessage() {}
func (x *LogoutResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[66]
+ mi := &file_daemon_proto_msgTypes[70]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4469,7 +4897,7 @@ func (x *LogoutResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead.
func (*LogoutResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{66}
+ return file_daemon_proto_rawDescGZIP(), []int{70}
}
type GetFeaturesRequest struct {
@@ -4480,7 +4908,7 @@ type GetFeaturesRequest struct {
func (x *GetFeaturesRequest) Reset() {
*x = GetFeaturesRequest{}
- mi := &file_daemon_proto_msgTypes[67]
+ mi := &file_daemon_proto_msgTypes[71]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4492,7 +4920,7 @@ func (x *GetFeaturesRequest) String() string {
func (*GetFeaturesRequest) ProtoMessage() {}
func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[67]
+ mi := &file_daemon_proto_msgTypes[71]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4505,7 +4933,7 @@ func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetFeaturesRequest.ProtoReflect.Descriptor instead.
func (*GetFeaturesRequest) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{67}
+ return file_daemon_proto_rawDescGZIP(), []int{71}
}
type GetFeaturesResponse struct {
@@ -4518,7 +4946,7 @@ type GetFeaturesResponse struct {
func (x *GetFeaturesResponse) Reset() {
*x = GetFeaturesResponse{}
- mi := &file_daemon_proto_msgTypes[68]
+ mi := &file_daemon_proto_msgTypes[72]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4530,7 +4958,7 @@ func (x *GetFeaturesResponse) String() string {
func (*GetFeaturesResponse) ProtoMessage() {}
func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[68]
+ mi := &file_daemon_proto_msgTypes[72]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4543,7 +4971,7 @@ func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetFeaturesResponse.ProtoReflect.Descriptor instead.
func (*GetFeaturesResponse) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{68}
+ return file_daemon_proto_rawDescGZIP(), []int{72}
}
func (x *GetFeaturesResponse) GetDisableProfiles() bool {
@@ -4560,6 +4988,478 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool {
return false
}
+// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer
+type GetPeerSSHHostKeyRequest struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // peer IP address or FQDN to get SSH host key for
+ PeerAddress string `protobuf:"bytes,1,opt,name=peerAddress,proto3" json:"peerAddress,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *GetPeerSSHHostKeyRequest) Reset() {
+ *x = GetPeerSSHHostKeyRequest{}
+ mi := &file_daemon_proto_msgTypes[73]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *GetPeerSSHHostKeyRequest) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*GetPeerSSHHostKeyRequest) ProtoMessage() {}
+
+func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[73]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead.
+func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{73}
+}
+
+func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string {
+ if x != nil {
+ return x.PeerAddress
+ }
+ return ""
+}
+
+// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer
+type GetPeerSSHHostKeyResponse struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname")
+ SshHostKey []byte `protobuf:"bytes,1,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"`
+ // peer IP address
+ PeerIP string `protobuf:"bytes,2,opt,name=peerIP,proto3" json:"peerIP,omitempty"`
+ // peer FQDN
+ PeerFQDN string `protobuf:"bytes,3,opt,name=peerFQDN,proto3" json:"peerFQDN,omitempty"`
+ // indicates if the SSH host key was found
+ Found bool `protobuf:"varint,4,opt,name=found,proto3" json:"found,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *GetPeerSSHHostKeyResponse) Reset() {
+ *x = GetPeerSSHHostKeyResponse{}
+ mi := &file_daemon_proto_msgTypes[74]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *GetPeerSSHHostKeyResponse) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*GetPeerSSHHostKeyResponse) ProtoMessage() {}
+
+func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[74]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead.
+func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{74}
+}
+
+func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte {
+ if x != nil {
+ return x.SshHostKey
+ }
+ return nil
+}
+
+func (x *GetPeerSSHHostKeyResponse) GetPeerIP() string {
+ if x != nil {
+ return x.PeerIP
+ }
+ return ""
+}
+
+func (x *GetPeerSSHHostKeyResponse) GetPeerFQDN() string {
+ if x != nil {
+ return x.PeerFQDN
+ }
+ return ""
+}
+
+func (x *GetPeerSSHHostKeyResponse) GetFound() bool {
+ if x != nil {
+ return x.Found
+ }
+ return false
+}
+
+// RequestJWTAuthRequest for initiating JWT authentication flow
+type RequestJWTAuthRequest struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // hint for OIDC login_hint parameter (typically email address)
+ Hint *string `protobuf:"bytes,1,opt,name=hint,proto3,oneof" json:"hint,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *RequestJWTAuthRequest) Reset() {
+ *x = RequestJWTAuthRequest{}
+ mi := &file_daemon_proto_msgTypes[75]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *RequestJWTAuthRequest) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*RequestJWTAuthRequest) ProtoMessage() {}
+
+func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[75]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead.
+func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{75}
+}
+
+func (x *RequestJWTAuthRequest) GetHint() string {
+ if x != nil && x.Hint != nil {
+ return *x.Hint
+ }
+ return ""
+}
+
+// RequestJWTAuthResponse contains authentication flow information
+type RequestJWTAuthResponse struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // verification URI for user authentication
+ VerificationURI string `protobuf:"bytes,1,opt,name=verificationURI,proto3" json:"verificationURI,omitempty"`
+ // complete verification URI (with embedded user code)
+ VerificationURIComplete string `protobuf:"bytes,2,opt,name=verificationURIComplete,proto3" json:"verificationURIComplete,omitempty"`
+ // user code to enter on verification URI
+ UserCode string `protobuf:"bytes,3,opt,name=userCode,proto3" json:"userCode,omitempty"`
+ // device code for polling
+ DeviceCode string `protobuf:"bytes,4,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"`
+ // expiration time in seconds
+ ExpiresIn int64 `protobuf:"varint,5,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"`
+ // if a cached token is available, it will be returned here
+ CachedToken string `protobuf:"bytes,6,opt,name=cachedToken,proto3" json:"cachedToken,omitempty"`
+ // maximum age of JWT tokens in seconds (from management server)
+ MaxTokenAge int64 `protobuf:"varint,7,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *RequestJWTAuthResponse) Reset() {
+ *x = RequestJWTAuthResponse{}
+ mi := &file_daemon_proto_msgTypes[76]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *RequestJWTAuthResponse) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*RequestJWTAuthResponse) ProtoMessage() {}
+
+func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[76]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead.
+func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{76}
+}
+
+func (x *RequestJWTAuthResponse) GetVerificationURI() string {
+ if x != nil {
+ return x.VerificationURI
+ }
+ return ""
+}
+
+func (x *RequestJWTAuthResponse) GetVerificationURIComplete() string {
+ if x != nil {
+ return x.VerificationURIComplete
+ }
+ return ""
+}
+
+func (x *RequestJWTAuthResponse) GetUserCode() string {
+ if x != nil {
+ return x.UserCode
+ }
+ return ""
+}
+
+func (x *RequestJWTAuthResponse) GetDeviceCode() string {
+ if x != nil {
+ return x.DeviceCode
+ }
+ return ""
+}
+
+func (x *RequestJWTAuthResponse) GetExpiresIn() int64 {
+ if x != nil {
+ return x.ExpiresIn
+ }
+ return 0
+}
+
+func (x *RequestJWTAuthResponse) GetCachedToken() string {
+ if x != nil {
+ return x.CachedToken
+ }
+ return ""
+}
+
+func (x *RequestJWTAuthResponse) GetMaxTokenAge() int64 {
+ if x != nil {
+ return x.MaxTokenAge
+ }
+ return 0
+}
+
+// WaitJWTTokenRequest for waiting for authentication completion
+type WaitJWTTokenRequest struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // device code from RequestJWTAuthResponse
+ DeviceCode string `protobuf:"bytes,1,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"`
+ // user code for verification
+ UserCode string `protobuf:"bytes,2,opt,name=userCode,proto3" json:"userCode,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *WaitJWTTokenRequest) Reset() {
+ *x = WaitJWTTokenRequest{}
+ mi := &file_daemon_proto_msgTypes[77]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *WaitJWTTokenRequest) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*WaitJWTTokenRequest) ProtoMessage() {}
+
+func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[77]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead.
+func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{77}
+}
+
+func (x *WaitJWTTokenRequest) GetDeviceCode() string {
+ if x != nil {
+ return x.DeviceCode
+ }
+ return ""
+}
+
+func (x *WaitJWTTokenRequest) GetUserCode() string {
+ if x != nil {
+ return x.UserCode
+ }
+ return ""
+}
+
+// WaitJWTTokenResponse contains the JWT token after authentication
+type WaitJWTTokenResponse struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // JWT token (access token or ID token)
+ Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"`
+ // token type (e.g., "Bearer")
+ TokenType string `protobuf:"bytes,2,opt,name=tokenType,proto3" json:"tokenType,omitempty"`
+ // expiration time in seconds
+ ExpiresIn int64 `protobuf:"varint,3,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *WaitJWTTokenResponse) Reset() {
+ *x = WaitJWTTokenResponse{}
+ mi := &file_daemon_proto_msgTypes[78]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *WaitJWTTokenResponse) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*WaitJWTTokenResponse) ProtoMessage() {}
+
+func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[78]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead.
+func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{78}
+}
+
+func (x *WaitJWTTokenResponse) GetToken() string {
+ if x != nil {
+ return x.Token
+ }
+ return ""
+}
+
+func (x *WaitJWTTokenResponse) GetTokenType() string {
+ if x != nil {
+ return x.TokenType
+ }
+ return ""
+}
+
+func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
+ if x != nil {
+ return x.ExpiresIn
+ }
+ return 0
+}
+
+type InstallerResultRequest struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *InstallerResultRequest) Reset() {
+ *x = InstallerResultRequest{}
+ mi := &file_daemon_proto_msgTypes[79]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *InstallerResultRequest) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*InstallerResultRequest) ProtoMessage() {}
+
+func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[79]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
+func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{79}
+}
+
+type InstallerResultResponse struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
+ ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *InstallerResultResponse) Reset() {
+ *x = InstallerResultResponse{}
+ mi := &file_daemon_proto_msgTypes[80]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *InstallerResultResponse) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*InstallerResultResponse) ProtoMessage() {}
+
+func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[80]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
+func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{80}
+}
+
+func (x *InstallerResultResponse) GetSuccess() bool {
+ if x != nil {
+ return x.Success
+ }
+ return false
+}
+
+func (x *InstallerResultResponse) GetErrorMsg() string {
+ if x != nil {
+ return x.ErrorMsg
+ }
+ return ""
+}
+
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -4570,7 +5470,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
- mi := &file_daemon_proto_msgTypes[70]
+ mi := &file_daemon_proto_msgTypes[82]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4582,7 +5482,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[70]
+ mi := &file_daemon_proto_msgTypes[82]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4595,7 +5495,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
// Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead.
func (*PortInfo_Range) Descriptor() ([]byte, []int) {
- return file_daemon_proto_rawDescGZIP(), []int{26, 0}
+ return file_daemon_proto_rawDescGZIP(), []int{30, 0}
}
func (x *PortInfo_Range) GetStart() uint32 {
@@ -4617,7 +5517,15 @@ var File_daemon_proto protoreflect.FileDescriptor
const file_daemon_proto_rawDesc = "" +
"\n" +
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
- "\fEmptyRequest\"\xc3\x0e\n" +
+ "\fEmptyRequest\"\x7f\n" +
+ "\x12OSLifecycleRequest\x128\n" +
+ "\x04type\x18\x01 \x01(\x0e2$.daemon.OSLifecycleRequest.CycleTypeR\x04type\"/\n" +
+ "\tCycleType\x12\v\n" +
+ "\aUNKNOWN\x10\x00\x12\t\n" +
+ "\x05SLEEP\x10\x01\x12\n" +
+ "\n" +
+ "\x06WAKEUP\x10\x02\"\x15\n" +
+ "\x13OSLifecycleResponse\"\xb6\x12\n" +
"\fLoginRequest\x12\x1a\n" +
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
@@ -4654,7 +5562,14 @@ const file_daemon_proto_rawDesc = "" +
"\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" +
"\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" +
- "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" +
+ "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" +
+ "\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01\x12)\n" +
+ "\renableSSHRoot\x18\" \x01(\bH\x15R\renableSSHRoot\x88\x01\x01\x12)\n" +
+ "\renableSSHSFTP\x18# \x01(\bH\x16R\renableSSHSFTP\x88\x01\x01\x12G\n" +
+ "\x1cenableSSHLocalPortForwarding\x18$ \x01(\bH\x17R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" +
+ "\x1denableSSHRemotePortForwarding\x18% \x01(\bH\x18R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" +
+ "\x0edisableSSHAuth\x18& \x01(\bH\x19R\x0edisableSSHAuth\x88\x01\x01\x12+\n" +
+ "\x0esshJWTCacheTTL\x18' \x01(\x05H\x1aR\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" +
@@ -4674,7 +5589,14 @@ const file_daemon_proto_rawDesc = "" +
"\x0e_block_inboundB\x0e\n" +
"\f_profileNameB\v\n" +
"\t_usernameB\x06\n" +
- "\x04_mtu\"\xb5\x01\n" +
+ "\x04_mtuB\a\n" +
+ "\x05_hintB\x10\n" +
+ "\x0e_enableSSHRootB\x10\n" +
+ "\x0e_enableSSHSFTPB\x1f\n" +
+ "\x1d_enableSSHLocalPortForwardingB \n" +
+ "\x1e_enableSSHRemotePortForwardingB\x11\n" +
+ "\x0f_disableSSHAuthB\x11\n" +
+ "\x0f_sshJWTCacheTTL\"\xb5\x01\n" +
"\rLoginResponse\x12$\n" +
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +
@@ -4684,12 +5606,16 @@ const file_daemon_proto_rawDesc = "" +
"\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" +
"\bhostname\x18\x02 \x01(\tR\bhostname\",\n" +
"\x14WaitSSOLoginResponse\x12\x14\n" +
- "\x05email\x18\x01 \x01(\tR\x05email\"p\n" +
+ "\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" +
"\tUpRequest\x12%\n" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
- "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
+ "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" +
+ "\n" +
+ "autoUpdate\x18\x03 \x01(\bH\x02R\n" +
+ "autoUpdate\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
- "\t_username\"\f\n" +
+ "\t_usernameB\r\n" +
+ "\v_autoUpdate\"\f\n" +
"\n" +
"UpResponse\"\xa1\x01\n" +
"\rStatusRequest\x12,\n" +
@@ -4707,7 +5633,7 @@ const file_daemon_proto_rawDesc = "" +
"\fDownResponse\"P\n" +
"\x10GetConfigRequest\x12 \n" +
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
- "\busername\x18\x02 \x01(\tR\busername\"\xe0\x06\n" +
+ "\busername\x18\x02 \x01(\tR\busername\"\x86\t\n" +
"\x11GetConfigResponse\x12$\n" +
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
"\n" +
@@ -4732,8 +5658,14 @@ const file_daemon_proto_rawDesc = "" +
"disableDns\x122\n" +
"\x15disable_client_routes\x18\x12 \x01(\bR\x13disableClientRoutes\x122\n" +
"\x15disable_server_routes\x18\x13 \x01(\bR\x13disableServerRoutes\x12(\n" +
- "\x10block_lan_access\x18\x14 \x01(\bR\x0eblockLanAccess\x12)\n" +
- "\x10disable_firewall\x18\x15 \x01(\bR\x0fdisableFirewall\"\xde\x05\n" +
+ "\x10block_lan_access\x18\x14 \x01(\bR\x0eblockLanAccess\x12$\n" +
+ "\renableSSHRoot\x18\x15 \x01(\bR\renableSSHRoot\x12$\n" +
+ "\renableSSHSFTP\x18\x18 \x01(\bR\renableSSHSFTP\x12B\n" +
+ "\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" +
+ "\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" +
+ "\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\x12&\n" +
+ "\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\x12)\n" +
+ "\x10disable_firewall\x18\x1b \x01(\bR\x0fdisableFirewall\"\xfe\x05\n" +
"\tPeerState\x12\x0e\n" +
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" +
@@ -4754,7 +5686,10 @@ const file_daemon_proto_rawDesc = "" +
"\x10rosenpassEnabled\x18\x0f \x01(\bR\x10rosenpassEnabled\x12\x1a\n" +
"\bnetworks\x18\x10 \x03(\tR\bnetworks\x123\n" +
"\alatency\x18\x11 \x01(\v2\x19.google.protobuf.DurationR\alatency\x12\"\n" +
- "\frelayAddress\x18\x12 \x01(\tR\frelayAddress\"\xf0\x01\n" +
+ "\frelayAddress\x18\x12 \x01(\tR\frelayAddress\x12\x1e\n" +
+ "\n" +
+ "sshHostKey\x18\x13 \x01(\fR\n" +
+ "sshHostKey\"\xf0\x01\n" +
"\x0eLocalPeerState\x12\x0e\n" +
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12(\n" +
@@ -4780,7 +5715,15 @@ const file_daemon_proto_rawDesc = "" +
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +
"\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" +
- "\x05error\x18\x04 \x01(\tR\x05error\"\xef\x03\n" +
+ "\x05error\x18\x04 \x01(\tR\x05error\"\x8e\x01\n" +
+ "\x0eSSHSessionInfo\x12\x1a\n" +
+ "\busername\x18\x01 \x01(\tR\busername\x12$\n" +
+ "\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" +
+ "\acommand\x18\x03 \x01(\tR\acommand\x12 \n" +
+ "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" +
+ "\x0eSSHServerState\x12\x18\n" +
+ "\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" +
+ "\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" +
"\n" +
"FullStatus\x12A\n" +
"\x0fmanagementState\x18\x01 \x01(\v2\x17.daemon.ManagementStateR\x0fmanagementState\x125\n" +
@@ -4792,7 +5735,9 @@ const file_daemon_proto_rawDesc = "" +
"dnsServers\x128\n" +
"\x17NumberOfForwardingRules\x18\b \x01(\x05R\x17NumberOfForwardingRules\x12+\n" +
"\x06events\x18\a \x03(\v2\x13.daemon.SystemEventR\x06events\x124\n" +
- "\x15lazyConnectionEnabled\x18\t \x01(\bR\x15lazyConnectionEnabled\"\x15\n" +
+ "\x15lazyConnectionEnabled\x18\t \x01(\bR\x15lazyConnectionEnabled\x12>\n" +
+ "\x0esshServerState\x18\n" +
+ " \x01(\v2\x16.daemon.SSHServerStateR\x0esshServerState\"\x15\n" +
"\x13ListNetworksRequest\"?\n" +
"\x14ListNetworksResponse\x12'\n" +
"\x06routes\x18\x01 \x03(\v2\x0f.daemon.NetworkR\x06routes\"a\n" +
@@ -4933,7 +5878,7 @@ const file_daemon_proto_rawDesc = "" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
"\t_username\"\x17\n" +
- "\x15SwitchProfileResponse\"\x8e\r\n" +
+ "\x15SwitchProfileResponse\"\xdf\x10\n" +
"\x10SetConfigRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" +
@@ -4966,7 +5911,13 @@ const file_daemon_proto_rawDesc = "" +
"dns_labels\x18\x19 \x03(\tR\tdnsLabels\x12&\n" +
"\x0ecleanDNSLabels\x18\x1a \x01(\bR\x0ecleanDNSLabels\x12J\n" +
"\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01\x12\x15\n" +
- "\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01B\x13\n" +
+ "\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01\x12)\n" +
+ "\renableSSHRoot\x18\x1d \x01(\bH\x12R\renableSSHRoot\x88\x01\x01\x12)\n" +
+ "\renableSSHSFTP\x18\x1e \x01(\bH\x13R\renableSSHSFTP\x88\x01\x01\x12G\n" +
+ "\x1cenableSSHLocalPortForwarding\x18\x1f \x01(\bH\x14R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" +
+ "\x1denableSSHRemotePortForwarding\x18 \x01(\bH\x15R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" +
+ "\x0edisableSSHAuth\x18! \x01(\bH\x16R\x0edisableSSHAuth\x88\x01\x01\x12+\n" +
+ "\x0esshJWTCacheTTL\x18\" \x01(\x05H\x17R\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" +
@@ -4984,7 +5935,13 @@ const file_daemon_proto_rawDesc = "" +
"\x16_lazyConnectionEnabledB\x10\n" +
"\x0e_block_inboundB\x13\n" +
"\x11_dnsRouteIntervalB\x06\n" +
- "\x04_mtu\"\x13\n" +
+ "\x04_mtuB\x10\n" +
+ "\x0e_enableSSHRootB\x10\n" +
+ "\x0e_enableSSHSFTPB\x1f\n" +
+ "\x1d_enableSSHLocalPortForwardingB \n" +
+ "\x1e_enableSSHRemotePortForwardingB\x11\n" +
+ "\x0f_disableSSHAuthB\x11\n" +
+ "\x0f_sshJWTCacheTTL\"\x13\n" +
"\x11SetConfigResponse\"Q\n" +
"\x11AddProfileRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
@@ -5014,7 +5971,42 @@ const file_daemon_proto_rawDesc = "" +
"\x12GetFeaturesRequest\"x\n" +
"\x13GetFeaturesResponse\x12)\n" +
"\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" +
- "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings*b\n" +
+ "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"<\n" +
+ "\x18GetPeerSSHHostKeyRequest\x12 \n" +
+ "\vpeerAddress\x18\x01 \x01(\tR\vpeerAddress\"\x85\x01\n" +
+ "\x19GetPeerSSHHostKeyResponse\x12\x1e\n" +
+ "\n" +
+ "sshHostKey\x18\x01 \x01(\fR\n" +
+ "sshHostKey\x12\x16\n" +
+ "\x06peerIP\x18\x02 \x01(\tR\x06peerIP\x12\x1a\n" +
+ "\bpeerFQDN\x18\x03 \x01(\tR\bpeerFQDN\x12\x14\n" +
+ "\x05found\x18\x04 \x01(\bR\x05found\"9\n" +
+ "\x15RequestJWTAuthRequest\x12\x17\n" +
+ "\x04hint\x18\x01 \x01(\tH\x00R\x04hint\x88\x01\x01B\a\n" +
+ "\x05_hint\"\x9a\x02\n" +
+ "\x16RequestJWTAuthResponse\x12(\n" +
+ "\x0fverificationURI\x18\x01 \x01(\tR\x0fverificationURI\x128\n" +
+ "\x17verificationURIComplete\x18\x02 \x01(\tR\x17verificationURIComplete\x12\x1a\n" +
+ "\buserCode\x18\x03 \x01(\tR\buserCode\x12\x1e\n" +
+ "\n" +
+ "deviceCode\x18\x04 \x01(\tR\n" +
+ "deviceCode\x12\x1c\n" +
+ "\texpiresIn\x18\x05 \x01(\x03R\texpiresIn\x12 \n" +
+ "\vcachedToken\x18\x06 \x01(\tR\vcachedToken\x12 \n" +
+ "\vmaxTokenAge\x18\a \x01(\x03R\vmaxTokenAge\"Q\n" +
+ "\x13WaitJWTTokenRequest\x12\x1e\n" +
+ "\n" +
+ "deviceCode\x18\x01 \x01(\tR\n" +
+ "deviceCode\x12\x1a\n" +
+ "\buserCode\x18\x02 \x01(\tR\buserCode\"h\n" +
+ "\x14WaitJWTTokenResponse\x12\x14\n" +
+ "\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
+ "\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
+ "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
+ "\x16InstallerResultRequest\"O\n" +
+ "\x17InstallerResultResponse\x12\x18\n" +
+ "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
+ "\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -5023,7 +6015,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
- "\x05TRACE\x10\a2\x8f\x10\n" +
+ "\x05TRACE\x10\a2\xb4\x13\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -5055,7 +6047,12 @@ const file_daemon_proto_rawDesc = "" +
"\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" +
"\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" +
"\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" +
- "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00B\bZ\x06/protob\x06proto3"
+ "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
+ "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
+ "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
+ "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
+ "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
+ "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -5069,180 +6066,206 @@ func file_daemon_proto_rawDescGZIP() []byte {
return file_daemon_proto_rawDescData
}
-var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
-var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 72)
+var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
+var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
- (SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity
- (SystemEvent_Category)(0), // 2: daemon.SystemEvent.Category
- (*EmptyRequest)(nil), // 3: daemon.EmptyRequest
- (*LoginRequest)(nil), // 4: daemon.LoginRequest
- (*LoginResponse)(nil), // 5: daemon.LoginResponse
- (*WaitSSOLoginRequest)(nil), // 6: daemon.WaitSSOLoginRequest
- (*WaitSSOLoginResponse)(nil), // 7: daemon.WaitSSOLoginResponse
- (*UpRequest)(nil), // 8: daemon.UpRequest
- (*UpResponse)(nil), // 9: daemon.UpResponse
- (*StatusRequest)(nil), // 10: daemon.StatusRequest
- (*StatusResponse)(nil), // 11: daemon.StatusResponse
- (*DownRequest)(nil), // 12: daemon.DownRequest
- (*DownResponse)(nil), // 13: daemon.DownResponse
- (*GetConfigRequest)(nil), // 14: daemon.GetConfigRequest
- (*GetConfigResponse)(nil), // 15: daemon.GetConfigResponse
- (*PeerState)(nil), // 16: daemon.PeerState
- (*LocalPeerState)(nil), // 17: daemon.LocalPeerState
- (*SignalState)(nil), // 18: daemon.SignalState
- (*ManagementState)(nil), // 19: daemon.ManagementState
- (*RelayState)(nil), // 20: daemon.RelayState
- (*NSGroupState)(nil), // 21: daemon.NSGroupState
- (*FullStatus)(nil), // 22: daemon.FullStatus
- (*ListNetworksRequest)(nil), // 23: daemon.ListNetworksRequest
- (*ListNetworksResponse)(nil), // 24: daemon.ListNetworksResponse
- (*SelectNetworksRequest)(nil), // 25: daemon.SelectNetworksRequest
- (*SelectNetworksResponse)(nil), // 26: daemon.SelectNetworksResponse
- (*IPList)(nil), // 27: daemon.IPList
- (*Network)(nil), // 28: daemon.Network
- (*PortInfo)(nil), // 29: daemon.PortInfo
- (*ForwardingRule)(nil), // 30: daemon.ForwardingRule
- (*ForwardingRulesResponse)(nil), // 31: daemon.ForwardingRulesResponse
- (*DebugBundleRequest)(nil), // 32: daemon.DebugBundleRequest
- (*DebugBundleResponse)(nil), // 33: daemon.DebugBundleResponse
- (*GetLogLevelRequest)(nil), // 34: daemon.GetLogLevelRequest
- (*GetLogLevelResponse)(nil), // 35: daemon.GetLogLevelResponse
- (*SetLogLevelRequest)(nil), // 36: daemon.SetLogLevelRequest
- (*SetLogLevelResponse)(nil), // 37: daemon.SetLogLevelResponse
- (*State)(nil), // 38: daemon.State
- (*ListStatesRequest)(nil), // 39: daemon.ListStatesRequest
- (*ListStatesResponse)(nil), // 40: daemon.ListStatesResponse
- (*CleanStateRequest)(nil), // 41: daemon.CleanStateRequest
- (*CleanStateResponse)(nil), // 42: daemon.CleanStateResponse
- (*DeleteStateRequest)(nil), // 43: daemon.DeleteStateRequest
- (*DeleteStateResponse)(nil), // 44: daemon.DeleteStateResponse
- (*SetSyncResponsePersistenceRequest)(nil), // 45: daemon.SetSyncResponsePersistenceRequest
- (*SetSyncResponsePersistenceResponse)(nil), // 46: daemon.SetSyncResponsePersistenceResponse
- (*TCPFlags)(nil), // 47: daemon.TCPFlags
- (*TracePacketRequest)(nil), // 48: daemon.TracePacketRequest
- (*TraceStage)(nil), // 49: daemon.TraceStage
- (*TracePacketResponse)(nil), // 50: daemon.TracePacketResponse
- (*SubscribeRequest)(nil), // 51: daemon.SubscribeRequest
- (*SystemEvent)(nil), // 52: daemon.SystemEvent
- (*GetEventsRequest)(nil), // 53: daemon.GetEventsRequest
- (*GetEventsResponse)(nil), // 54: daemon.GetEventsResponse
- (*SwitchProfileRequest)(nil), // 55: daemon.SwitchProfileRequest
- (*SwitchProfileResponse)(nil), // 56: daemon.SwitchProfileResponse
- (*SetConfigRequest)(nil), // 57: daemon.SetConfigRequest
- (*SetConfigResponse)(nil), // 58: daemon.SetConfigResponse
- (*AddProfileRequest)(nil), // 59: daemon.AddProfileRequest
- (*AddProfileResponse)(nil), // 60: daemon.AddProfileResponse
- (*RemoveProfileRequest)(nil), // 61: daemon.RemoveProfileRequest
- (*RemoveProfileResponse)(nil), // 62: daemon.RemoveProfileResponse
- (*ListProfilesRequest)(nil), // 63: daemon.ListProfilesRequest
- (*ListProfilesResponse)(nil), // 64: daemon.ListProfilesResponse
- (*Profile)(nil), // 65: daemon.Profile
- (*GetActiveProfileRequest)(nil), // 66: daemon.GetActiveProfileRequest
- (*GetActiveProfileResponse)(nil), // 67: daemon.GetActiveProfileResponse
- (*LogoutRequest)(nil), // 68: daemon.LogoutRequest
- (*LogoutResponse)(nil), // 69: daemon.LogoutResponse
- (*GetFeaturesRequest)(nil), // 70: daemon.GetFeaturesRequest
- (*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse
- nil, // 72: daemon.Network.ResolvedIPsEntry
- (*PortInfo_Range)(nil), // 73: daemon.PortInfo.Range
- nil, // 74: daemon.SystemEvent.MetadataEntry
- (*durationpb.Duration)(nil), // 75: google.protobuf.Duration
- (*timestamppb.Timestamp)(nil), // 76: google.protobuf.Timestamp
+ (OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
+ (SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity
+ (SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category
+ (*EmptyRequest)(nil), // 4: daemon.EmptyRequest
+ (*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest
+ (*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse
+ (*LoginRequest)(nil), // 7: daemon.LoginRequest
+ (*LoginResponse)(nil), // 8: daemon.LoginResponse
+ (*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest
+ (*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse
+ (*UpRequest)(nil), // 11: daemon.UpRequest
+ (*UpResponse)(nil), // 12: daemon.UpResponse
+ (*StatusRequest)(nil), // 13: daemon.StatusRequest
+ (*StatusResponse)(nil), // 14: daemon.StatusResponse
+ (*DownRequest)(nil), // 15: daemon.DownRequest
+ (*DownResponse)(nil), // 16: daemon.DownResponse
+ (*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest
+ (*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse
+ (*PeerState)(nil), // 19: daemon.PeerState
+ (*LocalPeerState)(nil), // 20: daemon.LocalPeerState
+ (*SignalState)(nil), // 21: daemon.SignalState
+ (*ManagementState)(nil), // 22: daemon.ManagementState
+ (*RelayState)(nil), // 23: daemon.RelayState
+ (*NSGroupState)(nil), // 24: daemon.NSGroupState
+ (*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo
+ (*SSHServerState)(nil), // 26: daemon.SSHServerState
+ (*FullStatus)(nil), // 27: daemon.FullStatus
+ (*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest
+ (*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse
+ (*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest
+ (*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse
+ (*IPList)(nil), // 32: daemon.IPList
+ (*Network)(nil), // 33: daemon.Network
+ (*PortInfo)(nil), // 34: daemon.PortInfo
+ (*ForwardingRule)(nil), // 35: daemon.ForwardingRule
+ (*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse
+ (*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest
+ (*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse
+ (*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest
+ (*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse
+ (*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest
+ (*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse
+ (*State)(nil), // 43: daemon.State
+ (*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest
+ (*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse
+ (*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest
+ (*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse
+ (*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest
+ (*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse
+ (*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest
+ (*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse
+ (*TCPFlags)(nil), // 52: daemon.TCPFlags
+ (*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest
+ (*TraceStage)(nil), // 54: daemon.TraceStage
+ (*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse
+ (*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest
+ (*SystemEvent)(nil), // 57: daemon.SystemEvent
+ (*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest
+ (*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse
+ (*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest
+ (*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse
+ (*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest
+ (*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse
+ (*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest
+ (*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse
+ (*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest
+ (*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse
+ (*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest
+ (*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse
+ (*Profile)(nil), // 70: daemon.Profile
+ (*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest
+ (*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse
+ (*LogoutRequest)(nil), // 73: daemon.LogoutRequest
+ (*LogoutResponse)(nil), // 74: daemon.LogoutResponse
+ (*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest
+ (*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse
+ (*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest
+ (*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse
+ (*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest
+ (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
+ (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
+ (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
+ (*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
+ (*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
+ nil, // 85: daemon.Network.ResolvedIPsEntry
+ (*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
+ nil, // 87: daemon.SystemEvent.MetadataEntry
+ (*durationpb.Duration)(nil), // 88: google.protobuf.Duration
+ (*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
- 75, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
- 22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
- 76, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
- 76, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
- 75, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
- 19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
- 18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
- 17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
- 16, // 8: daemon.FullStatus.peers:type_name -> daemon.PeerState
- 20, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState
- 21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
- 52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent
- 28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
- 72, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
- 73, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
- 29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
- 29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
- 30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
- 0, // 18: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
- 0, // 19: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
- 38, // 20: daemon.ListStatesResponse.states:type_name -> daemon.State
- 47, // 21: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
- 49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
- 1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
- 2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
- 76, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
- 74, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
- 52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
- 75, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
- 65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
- 27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
- 4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
- 6, // 32: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
- 8, // 33: daemon.DaemonService.Up:input_type -> daemon.UpRequest
- 10, // 34: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
- 12, // 35: daemon.DaemonService.Down:input_type -> daemon.DownRequest
- 14, // 36: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
- 23, // 37: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
- 25, // 38: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
- 25, // 39: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
- 3, // 40: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
- 32, // 41: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
- 34, // 42: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
- 36, // 43: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
- 39, // 44: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
- 41, // 45: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
- 43, // 46: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
- 45, // 47: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
- 48, // 48: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
- 51, // 49: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
- 53, // 50: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
- 55, // 51: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
- 57, // 52: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
- 59, // 53: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
- 61, // 54: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
- 63, // 55: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
- 66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
- 68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
- 70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
- 5, // 59: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
- 7, // 60: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
- 9, // 61: daemon.DaemonService.Up:output_type -> daemon.UpResponse
- 11, // 62: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
- 13, // 63: daemon.DaemonService.Down:output_type -> daemon.DownResponse
- 15, // 64: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
- 24, // 65: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
- 26, // 66: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
- 26, // 67: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
- 31, // 68: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
- 33, // 69: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
- 35, // 70: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
- 37, // 71: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
- 40, // 72: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
- 42, // 73: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
- 44, // 74: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
- 46, // 75: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
- 50, // 76: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
- 52, // 77: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
- 54, // 78: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
- 56, // 79: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
- 58, // 80: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
- 60, // 81: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
- 62, // 82: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
- 64, // 83: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
- 67, // 84: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
- 69, // 85: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
- 71, // 86: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
- 59, // [59:87] is the sub-list for method output_type
- 31, // [31:59] is the sub-list for method input_type
- 31, // [31:31] is the sub-list for extension type_name
- 31, // [31:31] is the sub-list for extension extendee
- 0, // [0:31] is the sub-list for field type_name
+ 1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
+ 88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
+ 27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
+ 89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
+ 89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
+ 88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
+ 25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
+ 22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
+ 21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
+ 20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
+ 19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
+ 23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
+ 24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
+ 57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
+ 26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
+ 33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
+ 85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
+ 86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
+ 34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
+ 34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
+ 35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
+ 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
+ 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
+ 43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
+ 52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
+ 54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
+ 2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
+ 3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
+ 89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
+ 87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
+ 57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
+ 88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
+ 70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
+ 32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
+ 7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
+ 9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
+ 11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest
+ 13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
+ 15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest
+ 17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
+ 28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
+ 30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
+ 30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
+ 4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
+ 37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
+ 39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
+ 41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
+ 44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
+ 46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
+ 48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
+ 50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
+ 53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
+ 56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
+ 58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
+ 60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
+ 62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
+ 64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
+ 66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
+ 68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
+ 71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
+ 73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
+ 75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
+ 77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
+ 79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
+ 81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
+ 5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
+ 83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
+ 8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
+ 10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
+ 12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
+ 14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
+ 16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
+ 18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
+ 29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
+ 31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
+ 31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
+ 36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
+ 38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
+ 40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
+ 42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
+ 45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
+ 47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
+ 49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
+ 51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
+ 55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
+ 57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
+ 59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
+ 61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
+ 63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
+ 65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
+ 67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
+ 69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
+ 72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
+ 74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
+ 76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
+ 78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
+ 80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
+ 82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
+ 6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
+ 84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
+ 67, // [67:100] is the sub-list for method output_type
+ 34, // [34:67] is the sub-list for method input_type
+ 34, // [34:34] is the sub-list for extension type_name
+ 34, // [34:34] is the sub-list for extension extendee
+ 0, // [0:34] is the sub-list for field type_name
}
func init() { file_daemon_proto_init() }
@@ -5250,25 +6273,26 @@ func file_daemon_proto_init() {
if File_daemon_proto != nil {
return
}
- file_daemon_proto_msgTypes[1].OneofWrappers = []any{}
- file_daemon_proto_msgTypes[5].OneofWrappers = []any{}
+ file_daemon_proto_msgTypes[3].OneofWrappers = []any{}
file_daemon_proto_msgTypes[7].OneofWrappers = []any{}
- file_daemon_proto_msgTypes[26].OneofWrappers = []any{
+ file_daemon_proto_msgTypes[9].OneofWrappers = []any{}
+ file_daemon_proto_msgTypes[30].OneofWrappers = []any{
(*PortInfo_Port)(nil),
(*PortInfo_Range_)(nil),
}
- file_daemon_proto_msgTypes[45].OneofWrappers = []any{}
- file_daemon_proto_msgTypes[46].OneofWrappers = []any{}
- file_daemon_proto_msgTypes[52].OneofWrappers = []any{}
- file_daemon_proto_msgTypes[54].OneofWrappers = []any{}
- file_daemon_proto_msgTypes[65].OneofWrappers = []any{}
+ file_daemon_proto_msgTypes[49].OneofWrappers = []any{}
+ file_daemon_proto_msgTypes[50].OneofWrappers = []any{}
+ file_daemon_proto_msgTypes[56].OneofWrappers = []any{}
+ file_daemon_proto_msgTypes[58].OneofWrappers = []any{}
+ file_daemon_proto_msgTypes[69].OneofWrappers = []any{}
+ file_daemon_proto_msgTypes[75].OneofWrappers = []any{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
- NumEnums: 3,
- NumMessages: 72,
+ NumEnums: 4,
+ NumMessages: 84,
NumExtensions: 0,
NumServices: 1,
},
diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto
index 3bf86873c..fb34e959d 100644
--- a/client/proto/daemon.proto
+++ b/client/proto/daemon.proto
@@ -24,7 +24,7 @@ service DaemonService {
// Status of the service.
rpc Status(StatusRequest) returns (StatusResponse) {}
- // Down engine work in the daemon.
+ // Down stops engine work in the daemon.
rpc Down(DownRequest) returns (DownResponse) {}
// GetConfig of the daemon.
@@ -84,9 +84,37 @@ service DaemonService {
rpc Logout(LogoutRequest) returns (LogoutResponse) {}
rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {}
+
+ // GetPeerSSHHostKey retrieves SSH host key for a specific peer
+ rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {}
+
+ // RequestJWTAuth initiates JWT authentication flow for SSH
+ rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
+
+ // WaitJWTToken waits for JWT authentication completion
+ rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
+
+ rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
+
+ rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
}
+
+message OSLifecycleRequest {
+ // avoid collision with loglevel enum
+ enum CycleType {
+ UNKNOWN = 0;
+ SLEEP = 1;
+ WAKEUP = 2;
+ }
+
+ CycleType type = 1;
+}
+
+message OSLifecycleResponse {}
+
+
message LoginRequest {
// setupKey netbird setup key.
string setupKey = 1;
@@ -158,6 +186,16 @@ message LoginRequest {
optional string username = 31;
optional int64 mtu = 32;
+
+ // hint is used to pre-fill the email/username field during SSO authentication
+ optional string hint = 33;
+
+ optional bool enableSSHRoot = 34;
+ optional bool enableSSHSFTP = 35;
+ optional bool enableSSHLocalPortForwarding = 36;
+ optional bool enableSSHRemotePortForwarding = 37;
+ optional bool disableSSHAuth = 38;
+ optional int32 sshJWTCacheTTL = 39;
}
message LoginResponse {
@@ -179,15 +217,16 @@ message WaitSSOLoginResponse {
message UpRequest {
optional string profileName = 1;
optional string username = 2;
+ optional bool autoUpdate = 3;
}
message UpResponse {}
message StatusRequest{
bool getFullPeerStatus = 1;
- bool shouldRunProbes = 2;
+ bool shouldRunProbes = 2;
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
- optional bool waitForReady = 3;
+ optional bool waitForReady = 3;
}
message StatusResponse{
@@ -253,7 +292,19 @@ message GetConfigResponse {
bool block_lan_access = 20;
- bool disable_firewall = 21;
+ bool enableSSHRoot = 21;
+
+ bool enableSSHSFTP = 24;
+
+ bool enableSSHLocalPortForwarding = 22;
+
+ bool enableSSHRemotePortForwarding = 23;
+
+ bool disableSSHAuth = 25;
+
+ int32 sshJWTCacheTTL = 26;
+
+ bool disable_firewall = 27;
}
// PeerState contains the latest state of a peer
@@ -275,6 +326,7 @@ message PeerState {
repeated string networks = 16;
google.protobuf.Duration latency = 17;
string relayAddress = 18;
+ bytes sshHostKey = 19;
}
// LocalPeerState contains the latest state of the local peer
@@ -316,6 +368,20 @@ message NSGroupState {
string error = 4;
}
+// SSHSessionInfo contains information about an active SSH session
+message SSHSessionInfo {
+ string username = 1;
+ string remoteAddress = 2;
+ string command = 3;
+ string jwtUsername = 4;
+}
+
+// SSHServerState contains the latest state of the SSH server
+message SSHServerState {
+ bool enabled = 1;
+ repeated SSHSessionInfo sessions = 2;
+}
+
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -329,6 +395,7 @@ message FullStatus {
repeated SystemEvent events = 7;
bool lazyConnectionEnabled = 9;
+ SSHServerState sshServerState = 10;
}
// Networks
@@ -542,56 +609,63 @@ message SwitchProfileRequest {
message SwitchProfileResponse {}
message SetConfigRequest {
- string username = 1;
- string profileName = 2;
- // managementUrl to authenticate.
- string managementUrl = 3;
+ string username = 1;
+ string profileName = 2;
+ // managementUrl to authenticate.
+ string managementUrl = 3;
- // adminUrl to manage keys.
- string adminURL = 4;
+ // adminUrl to manage keys.
+ string adminURL = 4;
- optional bool rosenpassEnabled = 5;
+ optional bool rosenpassEnabled = 5;
- optional string interfaceName = 6;
+ optional string interfaceName = 6;
- optional int64 wireguardPort = 7;
+ optional int64 wireguardPort = 7;
- optional string optionalPreSharedKey = 8;
+ optional string optionalPreSharedKey = 8;
- optional bool disableAutoConnect = 9;
+ optional bool disableAutoConnect = 9;
- optional bool serverSSHAllowed = 10;
+ optional bool serverSSHAllowed = 10;
- optional bool rosenpassPermissive = 11;
+ optional bool rosenpassPermissive = 11;
- optional bool networkMonitor = 12;
+ optional bool networkMonitor = 12;
- optional bool disable_client_routes = 13;
- optional bool disable_server_routes = 14;
- optional bool disable_dns = 15;
- optional bool disable_firewall = 16;
- optional bool block_lan_access = 17;
+ optional bool disable_client_routes = 13;
+ optional bool disable_server_routes = 14;
+ optional bool disable_dns = 15;
+ optional bool disable_firewall = 16;
+ optional bool block_lan_access = 17;
- optional bool disable_notifications = 18;
+ optional bool disable_notifications = 18;
- optional bool lazyConnectionEnabled = 19;
+ optional bool lazyConnectionEnabled = 19;
- optional bool block_inbound = 20;
+ optional bool block_inbound = 20;
- repeated string natExternalIPs = 21;
- bool cleanNATExternalIPs = 22;
+ repeated string natExternalIPs = 21;
+ bool cleanNATExternalIPs = 22;
- bytes customDNSAddress = 23;
+ bytes customDNSAddress = 23;
- repeated string extraIFaceBlacklist = 24;
+ repeated string extraIFaceBlacklist = 24;
- repeated string dns_labels = 25;
- // cleanDNSLabels clean map list of DNS labels.
- bool cleanDNSLabels = 26;
+ repeated string dns_labels = 25;
+ // cleanDNSLabels clean map list of DNS labels.
+ bool cleanDNSLabels = 26;
- optional google.protobuf.Duration dnsRouteInterval = 27;
+ optional google.protobuf.Duration dnsRouteInterval = 27;
- optional int64 mtu = 28;
+ optional int64 mtu = 28;
+
+ optional bool enableSSHRoot = 29;
+ optional bool enableSSHSFTP = 30;
+ optional bool enableSSHLocalPortForwarding = 31;
+ optional bool enableSSHRemotePortForwarding = 32;
+ optional bool disableSSHAuth = 33;
+ optional int32 sshJWTCacheTTL = 34;
}
message SetConfigResponse{}
@@ -643,3 +717,71 @@ message GetFeaturesResponse{
bool disable_profiles = 1;
bool disable_update_settings = 2;
}
+
+// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer
+message GetPeerSSHHostKeyRequest {
+ // peer IP address or FQDN to get SSH host key for
+ string peerAddress = 1;
+}
+
+// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer
+message GetPeerSSHHostKeyResponse {
+ // SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname")
+ bytes sshHostKey = 1;
+ // peer IP address
+ string peerIP = 2;
+ // peer FQDN
+ string peerFQDN = 3;
+ // indicates if the SSH host key was found
+ bool found = 4;
+}
+
+// RequestJWTAuthRequest for initiating JWT authentication flow
+message RequestJWTAuthRequest {
+ // hint for OIDC login_hint parameter (typically email address)
+ optional string hint = 1;
+}
+
+// RequestJWTAuthResponse contains authentication flow information
+message RequestJWTAuthResponse {
+ // verification URI for user authentication
+ string verificationURI = 1;
+ // complete verification URI (with embedded user code)
+ string verificationURIComplete = 2;
+ // user code to enter on verification URI
+ string userCode = 3;
+ // device code for polling
+ string deviceCode = 4;
+ // expiration time in seconds
+ int64 expiresIn = 5;
+ // if a cached token is available, it will be returned here
+ string cachedToken = 6;
+ // maximum age of JWT tokens in seconds (from management server)
+ int64 maxTokenAge = 7;
+}
+
+// WaitJWTTokenRequest for waiting for authentication completion
+message WaitJWTTokenRequest {
+ // device code from RequestJWTAuthResponse
+ string deviceCode = 1;
+ // user code for verification
+ string userCode = 2;
+}
+
+// WaitJWTTokenResponse contains the JWT token after authentication
+message WaitJWTTokenResponse {
+ // JWT token (access token or ID token)
+ string token = 1;
+ // token type (e.g., "Bearer")
+ string tokenType = 2;
+ // expiration time in seconds
+ int64 expiresIn = 3;
+}
+
+message InstallerResultRequest {
+}
+
+message InstallerResultResponse {
+ bool success = 1;
+ string errorMsg = 2;
+}
diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go
index bf7c9c7b3..fdabb1879 100644
--- a/client/proto/daemon_grpc.pb.go
+++ b/client/proto/daemon_grpc.pb.go
@@ -27,7 +27,7 @@ type DaemonServiceClient interface {
Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error)
// Status of the service.
Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error)
- // Down engine work in the daemon.
+ // Down stops engine work in the daemon.
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
@@ -64,6 +64,14 @@ type DaemonServiceClient interface {
// Logout disconnects from the network and deletes the peer from the management server
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
+ // GetPeerSSHHostKey retrieves SSH host key for a specific peer
+ GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error)
+ // RequestJWTAuth initiates JWT authentication flow for SSH
+ RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
+ // WaitJWTToken waits for JWT authentication completion
+ WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
+ NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
+ GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
}
type daemonServiceClient struct {
@@ -349,6 +357,51 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe
return out, nil
}
+func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) {
+ out := new(GetPeerSSHHostKeyResponse)
+ err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) {
+ out := new(RequestJWTAuthResponse)
+ err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) {
+ out := new(WaitJWTTokenResponse)
+ err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
+ out := new(OSLifecycleResponse)
+ err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) {
+ out := new(InstallerResultResponse)
+ err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -362,7 +415,7 @@ type DaemonServiceServer interface {
Up(context.Context, *UpRequest) (*UpResponse, error)
// Status of the service.
Status(context.Context, *StatusRequest) (*StatusResponse, error)
- // Down engine work in the daemon.
+ // Down stops engine work in the daemon.
Down(context.Context, *DownRequest) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
@@ -399,6 +452,14 @@ type DaemonServiceServer interface {
// Logout disconnects from the network and deletes the peer from the management server
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
+ // GetPeerSSHHostKey retrieves SSH host key for a specific peer
+ GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error)
+ // RequestJWTAuth initiates JWT authentication flow for SSH
+ RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
+ // WaitJWTToken waits for JWT authentication completion
+ WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
+ NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
+ GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -490,6 +551,21 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest)
func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented")
}
+func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) {
+ return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented")
+}
+func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) {
+ return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented")
+}
+func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
+ return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
+}
+func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
+ return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
+}
+func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
+ return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
+}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1010,6 +1086,96 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de
return interceptor(ctx, in, info, handler)
}
+func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(GetPeerSSHHostKeyRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(RequestJWTAuthRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/daemon.DaemonService/RequestJWTAuth",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(WaitJWTTokenRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(DaemonServiceServer).WaitJWTToken(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/daemon.DaemonService/WaitJWTToken",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(OSLifecycleRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/daemon.DaemonService/NotifyOSLifecycle",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(InstallerResultRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(DaemonServiceServer).GetInstallerResult(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/daemon.DaemonService/GetInstallerResult",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1125,6 +1291,26 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetFeatures",
Handler: _DaemonService_GetFeatures_Handler,
},
+ {
+ MethodName: "GetPeerSSHHostKey",
+ Handler: _DaemonService_GetPeerSSHHostKey_Handler,
+ },
+ {
+ MethodName: "RequestJWTAuth",
+ Handler: _DaemonService_RequestJWTAuth_Handler,
+ },
+ {
+ MethodName: "WaitJWTToken",
+ Handler: _DaemonService_WaitJWTToken_Handler,
+ },
+ {
+ MethodName: "NotifyOSLifecycle",
+ Handler: _DaemonService_NotifyOSLifecycle_Handler,
+ },
+ {
+ MethodName: "GetInstallerResult",
+ Handler: _DaemonService_GetInstallerResult_Handler,
+ },
},
Streams: []grpc.StreamDesc{
{
diff --git a/client/proto/generate.sh b/client/proto/generate.sh
index f9a2c3750..e659cef90 100755
--- a/client/proto/generate.sh
+++ b/client/proto/generate.sh
@@ -14,4 +14,4 @@ cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
-cd "$old_pwd"
\ No newline at end of file
+cd "$old_pwd"
diff --git a/client/server/jwt_cache.go b/client/server/jwt_cache.go
new file mode 100644
index 000000000..21e170517
--- /dev/null
+++ b/client/server/jwt_cache.go
@@ -0,0 +1,79 @@
+package server
+
+import (
+ "sync"
+ "time"
+
+ "github.com/awnumar/memguard"
+ log "github.com/sirupsen/logrus"
+)
+
+type jwtCache struct {
+ mu sync.RWMutex
+ enclave *memguard.Enclave
+ expiresAt time.Time
+ timer *time.Timer
+ maxTokenSize int
+}
+
+func newJWTCache() *jwtCache {
+ return &jwtCache{
+ maxTokenSize: 8192,
+ }
+}
+
+func (c *jwtCache) store(token string, maxAge time.Duration) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.cleanup()
+
+ if c.timer != nil {
+ c.timer.Stop()
+ }
+
+ tokenBytes := []byte(token)
+ c.enclave = memguard.NewEnclave(tokenBytes)
+
+ c.expiresAt = time.Now().Add(maxAge)
+
+ var timer *time.Timer
+ timer = time.AfterFunc(maxAge, func() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.timer != timer {
+ return
+ }
+ c.cleanup()
+ c.timer = nil
+ log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
+ })
+ c.timer = timer
+}
+
+func (c *jwtCache) get() (string, bool) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ if c.enclave == nil || time.Now().After(c.expiresAt) {
+ return "", false
+ }
+
+ buffer, err := c.enclave.Open()
+ if err != nil {
+ log.Debugf("Failed to open JWT token enclave: %v", err)
+ return "", false
+ }
+ defer buffer.Destroy()
+
+ token := string(buffer.Bytes())
+ return token, true
+}
+
+// cleanup destroys the secure enclave, must be called with lock held
+func (c *jwtCache) cleanup() {
+ if c.enclave != nil {
+ c.enclave = nil
+ }
+ c.expiresAt = time.Time{}
+}
diff --git a/client/server/lifecycle.go b/client/server/lifecycle.go
new file mode 100644
index 000000000..3722c027d
--- /dev/null
+++ b/client/server/lifecycle.go
@@ -0,0 +1,77 @@
+package server
+
+import (
+ "context"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/proto"
+)
+
+// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
+func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
+ switch req.GetType() {
+ case proto.OSLifecycleRequest_WAKEUP:
+ return s.handleWakeUp(callerCtx)
+ case proto.OSLifecycleRequest_SLEEP:
+ return s.handleSleep(callerCtx)
+ default:
+ log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
+ }
+ return &proto.OSLifecycleResponse{}, nil
+}
+
+// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
+// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
+func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
+ if !s.sleepTriggeredDown.Load() {
+ log.Info("skipping up because wasn't sleep down")
+ return &proto.OSLifecycleResponse{}, nil
+ }
+
+ // avoid other wakeup runs if sleep didn't make the computer sleep
+ s.sleepTriggeredDown.Store(false)
+
+ log.Info("running up after wake up")
+ _, err := s.Up(callerCtx, &proto.UpRequest{})
+ if err != nil {
+ log.Errorf("running up failed: %v", err)
+ return &proto.OSLifecycleResponse{}, err
+ }
+
+ log.Info("running up command executed successfully")
+ return &proto.OSLifecycleResponse{}, nil
+}
+
+// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
+func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
+ s.mutex.Lock()
+
+ state := internal.CtxGetState(s.rootCtx)
+ status, err := state.Status()
+ if err != nil {
+ s.mutex.Unlock()
+ return &proto.OSLifecycleResponse{}, err
+ }
+
+ if status != internal.StatusConnecting && status != internal.StatusConnected {
+ log.Infof("skipping setting the agent down because status is %s", status)
+ s.mutex.Unlock()
+ return &proto.OSLifecycleResponse{}, nil
+ }
+ s.mutex.Unlock()
+
+ log.Info("running down after system started sleeping")
+
+ _, err = s.Down(callerCtx, &proto.DownRequest{})
+ if err != nil {
+ log.Errorf("running down failed: %v", err)
+ return &proto.OSLifecycleResponse{}, err
+ }
+
+ s.sleepTriggeredDown.Store(true)
+
+ log.Info("running down executed successfully")
+ return &proto.OSLifecycleResponse{}, nil
+}
diff --git a/client/server/lifecycle_test.go b/client/server/lifecycle_test.go
new file mode 100644
index 000000000..a604c60af
--- /dev/null
+++ b/client/server/lifecycle_test.go
@@ -0,0 +1,219 @@
+package server
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/proto"
+)
+
+func newTestServer() *Server {
+ ctx := internal.CtxInitState(context.Background())
+ return &Server{
+ rootCtx: ctx,
+ statusRecorder: peer.NewRecorder(""),
+ }
+}
+
+func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
+ s := newTestServer()
+
+ // sleepTriggeredDown is false by default
+ assert.False(t, s.sleepTriggeredDown.Load())
+
+ resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
+ Type: proto.OSLifecycleRequest_WAKEUP,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, resp)
+ assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
+}
+
+func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
+ s := newTestServer()
+
+ state := internal.CtxGetState(s.rootCtx)
+ state.Set(internal.StatusIdle)
+
+ resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
+ Type: proto.OSLifecycleRequest_SLEEP,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, resp)
+ assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
+}
+
+func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
+ s := newTestServer()
+
+ state := internal.CtxGetState(s.rootCtx)
+ state.Set(internal.StatusNeedsLogin)
+
+ resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
+ Type: proto.OSLifecycleRequest_SLEEP,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, resp)
+ assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
+}
+
+func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
+ s := newTestServer()
+
+ state := internal.CtxGetState(s.rootCtx)
+ state.Set(internal.StatusConnecting)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ s.actCancel = cancel
+
+ resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
+ Type: proto.OSLifecycleRequest_SLEEP,
+ })
+
+ require.NoError(t, err)
+ assert.NotNil(t, resp, "handleSleep returns not nil response on success")
+ assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
+}
+
+func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
+ s := newTestServer()
+
+ state := internal.CtxGetState(s.rootCtx)
+ state.Set(internal.StatusConnected)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ s.actCancel = cancel
+
+ resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
+ Type: proto.OSLifecycleRequest_SLEEP,
+ })
+
+ require.NoError(t, err)
+ assert.NotNil(t, resp, "handleSleep returns not nil response on success")
+ assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
+}
+
+func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
+ s := newTestServer()
+
+ // Manually set the flag to simulate prior sleep down
+ s.sleepTriggeredDown.Store(true)
+
+ // WakeUp will try to call Up which fails without proper setup, but flag should reset first
+ _, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
+ Type: proto.OSLifecycleRequest_WAKEUP,
+ })
+
+ assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
+}
+
+func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
+ s := newTestServer()
+
+ // First wakeup without prior sleep - should be no-op
+ resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
+ Type: proto.OSLifecycleRequest_WAKEUP,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, resp)
+ assert.False(t, s.sleepTriggeredDown.Load())
+
+ // Simulate prior sleep
+ s.sleepTriggeredDown.Store(true)
+
+ // First wakeup after sleep - should reset flag
+ _, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
+ Type: proto.OSLifecycleRequest_WAKEUP,
+ })
+ assert.False(t, s.sleepTriggeredDown.Load())
+
+ // Second wakeup - should be no-op
+ resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
+ Type: proto.OSLifecycleRequest_WAKEUP,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, resp)
+ assert.False(t, s.sleepTriggeredDown.Load())
+}
+
+func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
+ s := newTestServer()
+
+ resp, err := s.handleWakeUp(context.Background())
+
+ require.NoError(t, err)
+ require.NotNil(t, resp)
+}
+
+func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
+ s := newTestServer()
+ s.sleepTriggeredDown.Store(true)
+
+ // Even if Up fails, flag should be reset
+ _, _ = s.handleWakeUp(context.Background())
+
+ assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
+}
+
+func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
+ tests := []struct {
+ name string
+ status internal.StatusType
+ }{
+ {"Idle", internal.StatusIdle},
+ {"NeedsLogin", internal.StatusNeedsLogin},
+ {"LoginFailed", internal.StatusLoginFailed},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := newTestServer()
+ state := internal.CtxGetState(s.rootCtx)
+ state.Set(tt.status)
+
+ resp, err := s.handleSleep(context.Background())
+
+ require.NoError(t, err)
+ require.NotNil(t, resp)
+ assert.False(t, s.sleepTriggeredDown.Load())
+ })
+ }
+}
+
+func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
+ tests := []struct {
+ name string
+ status internal.StatusType
+ }{
+ {"Connecting", internal.StatusConnecting},
+ {"Connected", internal.StatusConnected},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := newTestServer()
+ state := internal.CtxGetState(s.rootCtx)
+ state.Set(tt.status)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ s.actCancel = cancel
+
+ resp, err := s.handleSleep(ctx)
+
+ require.NoError(t, err)
+ assert.NotNil(t, resp)
+ assert.True(t, s.sleepTriggeredDown.Load())
+ })
+ }
+}
diff --git a/client/server/network.go b/client/server/network.go
index 18b16795d..bb1cce56c 100644
--- a/client/server/network.go
+++ b/client/server/network.go
@@ -11,8 +11,8 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto"
- "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
+ "github.com/netbirdio/netbird/shared/management/domain"
)
type selectRoute struct {
diff --git a/client/server/server.go b/client/server/server.go
index 052809362..fbb3f0d52 100644
--- a/client/server/server.go
+++ b/client/server/server.go
@@ -46,6 +46,9 @@ const (
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
+ // JWT token cache TTL for the client daemon (disabled by default)
+ defaultJWTCacheTTL = 0
+
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
@@ -81,6 +84,11 @@ type Server struct {
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
+
+ // sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
+ sleepTriggeredDown atomic.Bool
+
+ jwtCache *jwtCache
}
type oauthAuthFlow struct {
@@ -100,6 +108,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
+ jwtCache: newJWTCache(),
}
}
@@ -183,7 +192,7 @@ func (s *Server) Start() error {
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
- go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
+ go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan)
return nil
}
@@ -214,7 +223,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
-func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
+func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
s.mutex.Lock()
s.clientRunning = false
@@ -222,7 +231,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
if s.config.DisableAutoConnect {
- if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
+ if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
@@ -251,7 +260,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
runOperation := func() error {
- err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
+ err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan)
+ doInitialAutoUpdate = false
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
@@ -353,6 +363,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.CustomDNSAddress = []byte{}
}
+ config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
+
+ if msg.DnsRouteInterval != nil {
+ interval := msg.DnsRouteInterval.AsDuration()
+ config.DNSRouteInterval = &interval
+ }
+
config.RosenpassEnabled = msg.RosenpassEnabled
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
@@ -366,6 +383,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.DisableNotifications = msg.DisableNotifications
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
config.BlockInbound = msg.BlockInbound
+ config.EnableSSHRoot = msg.EnableSSHRoot
+ config.EnableSSHSFTP = msg.EnableSSHSFTP
+ config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding
+ config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding
+ if msg.DisableSSHAuth != nil {
+ config.DisableSSHAuth = msg.DisableSSHAuth
+ }
+ if msg.SshJWTCacheTTL != nil {
+ ttl := int(*msg.SshJWTCacheTTL)
+ config.SSHJWTCacheTTL = &ttl
+ }
if msg.Mtu != nil {
mtu := uint16(*msg.Mtu)
@@ -476,13 +504,17 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting)
if msg.SetupKey == "" {
- oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
+ hint := ""
+ if msg.Hint != nil {
+ hint = *msg.Hint
+ }
+ oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, false, hint)
if err != nil {
state.Set(internal.StatusLoginFailed)
return nil, err
}
- if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
+ if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous oauth flow info")
return &proto.LoginResponse{
@@ -499,7 +531,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
}
- authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
+ authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
log.Errorf("getting a request OAuth flow failed: %v", err)
return nil, err
@@ -697,7 +729,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
- go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
+
+ var doAutoUpdate bool
+ if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate {
+ doAutoUpdate = true
+ }
+ go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
return s.waitForUp(callerCtx)
}
@@ -791,6 +828,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
defer s.mutex.Unlock()
if err := s.cleanupConnection(); err != nil {
+ // todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err)
return nil, err
}
@@ -883,6 +921,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
}
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
+ // todo review to update the status in case any type of error
log.Errorf("failed to cleanup connection: %v", err)
return nil, err
}
@@ -1050,20 +1089,240 @@ func (s *Server) Status(
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
if msg.GetFullPeerStatus {
- if msg.ShouldRunProbes {
- s.runProbes()
- }
-
+ s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
+
+ pbFullStatus.SshServerState = s.getSSHServerState()
+
statusResponse.FullStatus = pbFullStatus
}
return &statusResponse, nil
}
-func (s *Server) runProbes() {
+// getSSHServerState retrieves the current SSH server state including enabled status and active sessions
+func (s *Server) getSSHServerState() *proto.SSHServerState {
+ s.mutex.Lock()
+ connectClient := s.connectClient
+ s.mutex.Unlock()
+
+ if connectClient == nil {
+ return nil
+ }
+
+ engine := connectClient.Engine()
+ if engine == nil {
+ return nil
+ }
+
+ enabled, sessions := engine.GetSSHServerStatus()
+ sshServerState := &proto.SSHServerState{
+ Enabled: enabled,
+ }
+
+ for _, session := range sessions {
+ sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{
+ Username: session.Username,
+ RemoteAddress: session.RemoteAddress,
+ Command: session.Command,
+ JwtUsername: session.JWTUsername,
+ })
+ }
+
+ return sshServerState
+}
+
+// GetPeerSSHHostKey retrieves SSH host key for a specific peer
+func (s *Server) GetPeerSSHHostKey(
+ ctx context.Context,
+ req *proto.GetPeerSSHHostKeyRequest,
+) (*proto.GetPeerSSHHostKeyResponse, error) {
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+
+ s.mutex.Lock()
+ connectClient := s.connectClient
+ statusRecorder := s.statusRecorder
+ s.mutex.Unlock()
+
+ if connectClient == nil {
+ return nil, errors.New("client not initialized")
+ }
+
+ engine := connectClient.Engine()
+ if engine == nil {
+ return nil, errors.New("engine not started")
+ }
+
+ peerAddress := req.GetPeerAddress()
+ hostKey, found := engine.GetPeerSSHKey(peerAddress)
+
+ response := &proto.GetPeerSSHHostKeyResponse{
+ Found: found,
+ }
+
+ if !found {
+ return response, nil
+ }
+
+ response.SshHostKey = hostKey
+
+ if statusRecorder == nil {
+ return response, nil
+ }
+
+ fullStatus := statusRecorder.GetFullStatus()
+ for _, peerState := range fullStatus.Peers {
+ if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
+ response.PeerIP = peerState.IP
+ response.PeerFQDN = peerState.FQDN
+ break
+ }
+ }
+
+ return response, nil
+}
+
+// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled)
+func (s *Server) getJWTCacheTTL() time.Duration {
+ s.mutex.Lock()
+ config := s.config
+ s.mutex.Unlock()
+
+ if config == nil || config.SSHJWTCacheTTL == nil {
+ return defaultJWTCacheTTL
+ }
+
+ seconds := *config.SSHJWTCacheTTL
+ if seconds == 0 {
+ log.Debug("SSH JWT cache disabled (configured to 0)")
+ return 0
+ }
+
+ ttl := time.Duration(seconds) * time.Second
+ log.Debugf("SSH JWT cache TTL set to %v from config", ttl)
+ return ttl
+}
+
+// RequestJWTAuth initiates JWT authentication flow for SSH
+func (s *Server) RequestJWTAuth(
+ ctx context.Context,
+ msg *proto.RequestJWTAuthRequest,
+) (*proto.RequestJWTAuthResponse, error) {
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+
+ s.mutex.Lock()
+ config := s.config
+ s.mutex.Unlock()
+
+ if config == nil {
+ return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
+ }
+
+ jwtCacheTTL := s.getJWTCacheTTL()
+ if jwtCacheTTL > 0 {
+ if cachedToken, found := s.jwtCache.get(); found {
+ log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
+
+ return &proto.RequestJWTAuthResponse{
+ CachedToken: cachedToken,
+ MaxTokenAge: int64(jwtCacheTTL.Seconds()),
+ }, nil
+ }
+ }
+
+ hint := ""
+ if msg.Hint != nil {
+ hint = *msg.Hint
+ }
+
+ if hint == "" {
+ hint = profilemanager.GetLoginHint()
+ }
+
+ isDesktop := isUnixRunningDesktop()
+ oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, false, hint)
+ if err != nil {
+ return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
+ }
+
+ authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
+ if err != nil {
+ return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
+ }
+
+ s.mutex.Lock()
+ s.oauthAuthFlow.flow = oAuthFlow
+ s.oauthAuthFlow.info = authInfo
+ s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
+ s.mutex.Unlock()
+
+ return &proto.RequestJWTAuthResponse{
+ VerificationURI: authInfo.VerificationURI,
+ VerificationURIComplete: authInfo.VerificationURIComplete,
+ UserCode: authInfo.UserCode,
+ DeviceCode: authInfo.DeviceCode,
+ ExpiresIn: int64(authInfo.ExpiresIn),
+ MaxTokenAge: int64(jwtCacheTTL.Seconds()),
+ }, nil
+}
+
+// WaitJWTToken waits for JWT authentication completion
+func (s *Server) WaitJWTToken(
+ ctx context.Context,
+ req *proto.WaitJWTTokenRequest,
+) (*proto.WaitJWTTokenResponse, error) {
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+
+ s.mutex.Lock()
+ oAuthFlow := s.oauthAuthFlow.flow
+ authInfo := s.oauthAuthFlow.info
+ s.mutex.Unlock()
+
+ if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
+ return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
+ }
+
+ tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
+ if err != nil {
+ return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
+ }
+
+ token := tokenInfo.GetTokenToUse()
+
+ jwtCacheTTL := s.getJWTCacheTTL()
+ if jwtCacheTTL > 0 {
+ s.jwtCache.store(token, jwtCacheTTL)
+ log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
+ } else {
+ log.Debug("JWT caching disabled, not storing token")
+ }
+
+ s.mutex.Lock()
+ s.oauthAuthFlow = oauthAuthFlow{}
+ s.mutex.Unlock()
+ return &proto.WaitJWTTokenResponse{
+ Token: tokenInfo.GetTokenToUse(),
+ TokenType: tokenInfo.TokenType,
+ ExpiresIn: int64(tokenInfo.ExpiresIn),
+ }, nil
+}
+
+func isUnixRunningDesktop() bool {
+ if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
+ return false
+ }
+ return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
+}
+
+func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil {
return
}
@@ -1074,7 +1333,7 @@ func (s *Server) runProbes() {
}
if time.Since(s.lastProbe) > probeThreshold {
- if engine.RunHealthProbes() {
+ if engine.RunHealthProbes(waitForProbeResult) {
s.lastProbe = time.Now()
}
}
@@ -1129,26 +1388,62 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
blockLANAccess := cfg.BlockLANAccess
disableFirewall := cfg.DisableFirewall
+ enableSSHRoot := false
+ if cfg.EnableSSHRoot != nil {
+ enableSSHRoot = *cfg.EnableSSHRoot
+ }
+
+ enableSSHSFTP := false
+ if cfg.EnableSSHSFTP != nil {
+ enableSSHSFTP = *cfg.EnableSSHSFTP
+ }
+
+ enableSSHLocalPortForwarding := false
+ if cfg.EnableSSHLocalPortForwarding != nil {
+ enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding
+ }
+
+ enableSSHRemotePortForwarding := false
+ if cfg.EnableSSHRemotePortForwarding != nil {
+ enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding
+ }
+
+ disableSSHAuth := false
+ if cfg.DisableSSHAuth != nil {
+ disableSSHAuth = *cfg.DisableSSHAuth
+ }
+
+ sshJWTCacheTTL := int32(0)
+ if cfg.SSHJWTCacheTTL != nil {
+ sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
+ }
+
return &proto.GetConfigResponse{
- ManagementUrl: managementURL.String(),
- PreSharedKey: preSharedKey,
- AdminURL: adminURL.String(),
- InterfaceName: cfg.WgIface,
- WireguardPort: int64(cfg.WgPort),
- Mtu: int64(cfg.MTU),
- DisableAutoConnect: cfg.DisableAutoConnect,
- ServerSSHAllowed: *cfg.ServerSSHAllowed,
- RosenpassEnabled: cfg.RosenpassEnabled,
- RosenpassPermissive: cfg.RosenpassPermissive,
- LazyConnectionEnabled: cfg.LazyConnectionEnabled,
- BlockInbound: cfg.BlockInbound,
- DisableNotifications: disableNotifications,
- NetworkMonitor: networkMonitor,
- DisableDns: disableDNS,
- DisableClientRoutes: disableClientRoutes,
- DisableServerRoutes: disableServerRoutes,
- BlockLanAccess: blockLANAccess,
- DisableFirewall: disableFirewall,
+ ManagementUrl: managementURL.String(),
+ PreSharedKey: preSharedKey,
+ AdminURL: adminURL.String(),
+ InterfaceName: cfg.WgIface,
+ WireguardPort: int64(cfg.WgPort),
+ Mtu: int64(cfg.MTU),
+ DisableAutoConnect: cfg.DisableAutoConnect,
+ ServerSSHAllowed: *cfg.ServerSSHAllowed,
+ RosenpassEnabled: cfg.RosenpassEnabled,
+ RosenpassPermissive: cfg.RosenpassPermissive,
+ LazyConnectionEnabled: cfg.LazyConnectionEnabled,
+ BlockInbound: cfg.BlockInbound,
+ DisableNotifications: disableNotifications,
+ NetworkMonitor: networkMonitor,
+ DisableDns: disableDNS,
+ DisableClientRoutes: disableClientRoutes,
+ DisableServerRoutes: disableServerRoutes,
+ BlockLanAccess: blockLANAccess,
+ EnableSSHRoot: enableSSHRoot,
+ EnableSSHSFTP: enableSSHSFTP,
+ EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
+ EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
+ DisableSSHAuth: disableSSHAuth,
+ SshJWTCacheTTL: sshJWTCacheTTL,
+ DisableFirewall: disableFirewall,
}, nil
}
@@ -1252,9 +1547,9 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
return features, nil
}
-func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
+func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error {
log.Tracef("running client connection")
- s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
+ s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
return err
@@ -1379,6 +1674,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
+ SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
diff --git a/client/server/server_test.go b/client/server/server_test.go
index e0a4805f6..69b4453ea 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -15,9 +15,14 @@ import (
"github.com/netbirdio/management-integrations/integrations"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
+ nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
+
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
- "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -31,7 +36,6 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
- "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -108,7 +112,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
- s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
+ s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
@@ -290,7 +294,6 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
}
t.Cleanup(cleanUp)
- peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
@@ -311,13 +314,19 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
- accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
+ requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
+ peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
+ networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
+ accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
- secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
- mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
+ secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
+ if err != nil {
+ return nil, "", err
+ }
+ mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}
diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go
new file mode 100644
index 000000000..8e360175d
--- /dev/null
+++ b/client/server/setconfig_test.go
@@ -0,0 +1,314 @@
+package server
+
+import (
+ "context"
+ "os/user"
+ "path/filepath"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "google.golang.org/protobuf/types/known/durationpb"
+
+ "github.com/netbirdio/netbird/client/internal/profilemanager"
+ "github.com/netbirdio/netbird/client/proto"
+)
+
+// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config.
+// This test uses reflection to detect when new fields are added but not handled in SetConfig.
+func TestSetConfig_AllFieldsSaved(t *testing.T) {
+ tempDir := t.TempDir()
+ origDefaultProfileDir := profilemanager.DefaultConfigPathDir
+ origDefaultConfigPath := profilemanager.DefaultConfigPath
+ origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
+ profilemanager.ConfigDirOverride = tempDir
+ profilemanager.DefaultConfigPathDir = tempDir
+ profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
+ profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
+ t.Cleanup(func() {
+ profilemanager.DefaultConfigPathDir = origDefaultProfileDir
+ profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
+ profilemanager.DefaultConfigPath = origDefaultConfigPath
+ profilemanager.ConfigDirOverride = ""
+ })
+
+ currUser, err := user.Current()
+ require.NoError(t, err)
+
+ profName := "test-profile"
+
+ ic := profilemanager.ConfigInput{
+ ConfigPath: filepath.Join(tempDir, profName+".json"),
+ ManagementURL: "https://api.netbird.io:443",
+ }
+ _, err = profilemanager.UpdateOrCreateConfig(ic)
+ require.NoError(t, err)
+
+ pm := profilemanager.ServiceManager{}
+ err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
+ Name: profName,
+ Username: currUser.Username,
+ })
+ require.NoError(t, err)
+
+ ctx := context.Background()
+ s := New(ctx, "console", "", false, false)
+
+ rosenpassEnabled := true
+ rosenpassPermissive := true
+ serverSSHAllowed := true
+ interfaceName := "utun100"
+ wireguardPort := int64(51820)
+ preSharedKey := "test-psk"
+ disableAutoConnect := true
+ networkMonitor := true
+ disableClientRoutes := true
+ disableServerRoutes := true
+ disableDNS := true
+ disableFirewall := true
+ blockLANAccess := true
+ disableNotifications := true
+ lazyConnectionEnabled := true
+ blockInbound := true
+ mtu := int64(1280)
+ sshJWTCacheTTL := int32(300)
+
+ req := &proto.SetConfigRequest{
+ ProfileName: profName,
+ Username: currUser.Username,
+ ManagementUrl: "https://new-api.netbird.io:443",
+ AdminURL: "https://new-admin.netbird.io",
+ RosenpassEnabled: &rosenpassEnabled,
+ RosenpassPermissive: &rosenpassPermissive,
+ ServerSSHAllowed: &serverSSHAllowed,
+ InterfaceName: &interfaceName,
+ WireguardPort: &wireguardPort,
+ OptionalPreSharedKey: &preSharedKey,
+ DisableAutoConnect: &disableAutoConnect,
+ NetworkMonitor: &networkMonitor,
+ DisableClientRoutes: &disableClientRoutes,
+ DisableServerRoutes: &disableServerRoutes,
+ DisableDns: &disableDNS,
+ DisableFirewall: &disableFirewall,
+ BlockLanAccess: &blockLANAccess,
+ DisableNotifications: &disableNotifications,
+ LazyConnectionEnabled: &lazyConnectionEnabled,
+ BlockInbound: &blockInbound,
+ NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
+ CleanNATExternalIPs: false,
+ CustomDNSAddress: []byte("1.1.1.1:53"),
+ ExtraIFaceBlacklist: []string{"eth1", "eth2"},
+ DnsLabels: []string{"label1", "label2"},
+ CleanDNSLabels: false,
+ DnsRouteInterval: durationpb.New(2 * time.Minute),
+ Mtu: &mtu,
+ SshJWTCacheTTL: &sshJWTCacheTTL,
+ }
+
+ _, err = s.SetConfig(ctx, req)
+ require.NoError(t, err)
+
+ profState := profilemanager.ActiveProfileState{
+ Name: profName,
+ Username: currUser.Username,
+ }
+ cfgPath, err := profState.FilePath()
+ require.NoError(t, err)
+
+ cfg, err := profilemanager.GetConfig(cfgPath)
+ require.NoError(t, err)
+
+ require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String())
+ require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String())
+ require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled)
+ require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
+ require.NotNil(t, cfg.ServerSSHAllowed)
+ require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
+ require.Equal(t, interfaceName, cfg.WgIface)
+ require.Equal(t, int(wireguardPort), cfg.WgPort)
+ require.Equal(t, preSharedKey, cfg.PreSharedKey)
+ require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect)
+ require.NotNil(t, cfg.NetworkMonitor)
+ require.Equal(t, networkMonitor, *cfg.NetworkMonitor)
+ require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes)
+ require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes)
+ require.Equal(t, disableDNS, cfg.DisableDNS)
+ require.Equal(t, disableFirewall, cfg.DisableFirewall)
+ require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
+ require.NotNil(t, cfg.DisableNotifications)
+ require.Equal(t, disableNotifications, *cfg.DisableNotifications)
+ require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
+ require.Equal(t, blockInbound, cfg.BlockInbound)
+ require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
+ require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress)
+ // IFaceBlackList contains defaults + extras
+ require.Contains(t, cfg.IFaceBlackList, "eth1")
+ require.Contains(t, cfg.IFaceBlackList, "eth2")
+ require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
+ require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
+ require.Equal(t, uint16(mtu), cfg.MTU)
+ require.NotNil(t, cfg.SSHJWTCacheTTL)
+ require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL)
+
+ verifyAllFieldsCovered(t, req)
+}
+
+// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest.
+// If a new field is added to SetConfigRequest, this function will fail the test,
+// forcing the developer to update both the SetConfig handler and this test.
+func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
+ t.Helper()
+
+ metadataFields := map[string]bool{
+ "state": true, // protobuf internal
+ "sizeCache": true, // protobuf internal
+ "unknownFields": true, // protobuf internal
+ "Username": true, // metadata
+ "ProfileName": true, // metadata
+ "CleanNATExternalIPs": true, // control flag for clearing
+ "CleanDNSLabels": true, // control flag for clearing
+ }
+
+ expectedFields := map[string]bool{
+ "ManagementUrl": true,
+ "AdminURL": true,
+ "RosenpassEnabled": true,
+ "RosenpassPermissive": true,
+ "ServerSSHAllowed": true,
+ "InterfaceName": true,
+ "WireguardPort": true,
+ "OptionalPreSharedKey": true,
+ "DisableAutoConnect": true,
+ "NetworkMonitor": true,
+ "DisableClientRoutes": true,
+ "DisableServerRoutes": true,
+ "DisableDns": true,
+ "DisableFirewall": true,
+ "BlockLanAccess": true,
+ "DisableNotifications": true,
+ "LazyConnectionEnabled": true,
+ "BlockInbound": true,
+ "NatExternalIPs": true,
+ "CustomDNSAddress": true,
+ "ExtraIFaceBlacklist": true,
+ "DnsLabels": true,
+ "DnsRouteInterval": true,
+ "Mtu": true,
+ "EnableSSHRoot": true,
+ "EnableSSHSFTP": true,
+ "EnableSSHLocalPortForwarding": true,
+ "EnableSSHRemotePortForwarding": true,
+ "DisableSSHAuth": true,
+ "SshJWTCacheTTL": true,
+ }
+
+ val := reflect.ValueOf(req).Elem()
+ typ := val.Type()
+
+ var unexpectedFields []string
+ for i := 0; i < val.NumField(); i++ {
+ field := typ.Field(i)
+ fieldName := field.Name
+
+ if metadataFields[fieldName] {
+ continue
+ }
+
+ if !expectedFields[fieldName] {
+ unexpectedFields = append(unexpectedFields, fieldName)
+ }
+ }
+
+ if len(unexpectedFields) > 0 {
+ t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields)
+ }
+}
+
+// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest.
+// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq.
+func TestCLIFlags_MappedToSetConfig(t *testing.T) {
+ // Map of CLI flag names to their corresponding SetConfigRequest field names.
+ // This map must be updated when adding new config-related CLI flags.
+ flagToField := map[string]string{
+ "management-url": "ManagementUrl",
+ "admin-url": "AdminURL",
+ "enable-rosenpass": "RosenpassEnabled",
+ "rosenpass-permissive": "RosenpassPermissive",
+ "allow-server-ssh": "ServerSSHAllowed",
+ "interface-name": "InterfaceName",
+ "wireguard-port": "WireguardPort",
+ "preshared-key": "OptionalPreSharedKey",
+ "disable-auto-connect": "DisableAutoConnect",
+ "network-monitor": "NetworkMonitor",
+ "disable-client-routes": "DisableClientRoutes",
+ "disable-server-routes": "DisableServerRoutes",
+ "disable-dns": "DisableDns",
+ "disable-firewall": "DisableFirewall",
+ "block-lan-access": "BlockLanAccess",
+ "block-inbound": "BlockInbound",
+ "enable-lazy-connection": "LazyConnectionEnabled",
+ "external-ip-map": "NatExternalIPs",
+ "dns-resolver-address": "CustomDNSAddress",
+ "extra-iface-blacklist": "ExtraIFaceBlacklist",
+ "extra-dns-labels": "DnsLabels",
+ "dns-router-interval": "DnsRouteInterval",
+ "mtu": "Mtu",
+ "enable-ssh-root": "EnableSSHRoot",
+ "enable-ssh-sftp": "EnableSSHSFTP",
+ "enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding",
+ "enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding",
+ "disable-ssh-auth": "DisableSSHAuth",
+ "ssh-jwt-cache-ttl": "SshJWTCacheTTL",
+ }
+
+ // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
+ fieldsWithoutCLIFlags := map[string]bool{
+ "DisableNotifications": true, // Only settable via UI
+ }
+
+ // Get all SetConfigRequest fields to verify our map is complete.
+ req := &proto.SetConfigRequest{}
+ val := reflect.ValueOf(req).Elem()
+ typ := val.Type()
+
+ var unmappedFields []string
+ for i := 0; i < val.NumField(); i++ {
+ field := typ.Field(i)
+ fieldName := field.Name
+
+ // Skip protobuf internal fields and metadata fields.
+ if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" {
+ continue
+ }
+ if fieldName == "Username" || fieldName == "ProfileName" {
+ continue
+ }
+ if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" {
+ continue
+ }
+
+ // Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag.
+ mappedToCLI := false
+ for _, mappedField := range flagToField {
+ if mappedField == fieldName {
+ mappedToCLI = true
+ break
+ }
+ }
+
+ hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName]
+
+ if !mappedToCLI && !hasNoCLIFlag {
+ unmappedFields = append(unmappedFields, fieldName)
+ }
+ }
+
+ if len(unmappedFields) > 0 {
+ t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+
+ "Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+
+ "add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields)
+ }
+
+ t.Log("All SetConfigRequest fields are properly documented")
+}
diff --git a/client/server/state.go b/client/server/state.go
index 107f55154..1cf85cd37 100644
--- a/client/server/state.go
+++ b/client/server/state.go
@@ -10,7 +10,9 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
+ nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto"
)
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
}
+ // clean up any remaining routes independently of the state file
+ if !nbnet.AdvancedRouting() {
+ if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
+ }
+ }
+
return nberrors.FormatErrorOrNil(merr)
}
diff --git a/client/server/state_generic.go b/client/server/state_generic.go
index e6c7bdd44..980ba0cda 100644
--- a/client/server/state_generic.go
+++ b/client/server/state_generic.go
@@ -6,9 +6,11 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
+ "github.com/netbirdio/netbird/client/ssh/config"
)
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})
+ mgr.RegisterState(&config.ShutdownState{})
}
diff --git a/client/server/state_linux.go b/client/server/state_linux.go
index 087628907..019477d8e 100644
--- a/client/server/state_linux.go
+++ b/client/server/state_linux.go
@@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
+ "github.com/netbirdio/netbird/client/ssh/config"
)
func registerStates(mgr *statemanager.Manager) {
@@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&systemops.ShutdownState{})
mgr.RegisterState(&nftables.ShutdownState{})
mgr.RegisterState(&iptables.ShutdownState{})
+ mgr.RegisterState(&config.ShutdownState{})
}
diff --git a/client/server/updateresult.go b/client/server/updateresult.go
new file mode 100644
index 000000000..8e00d5062
--- /dev/null
+++ b/client/server/updateresult.go
@@ -0,0 +1,30 @@
+package server
+
+import (
+ "context"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/installer"
+ "github.com/netbirdio/netbird/client/proto"
+)
+
+func (s *Server) GetInstallerResult(ctx context.Context, _ *proto.InstallerResultRequest) (*proto.InstallerResultResponse, error) {
+ inst := installer.New()
+ dir := inst.TempDir()
+
+ rh := installer.NewResultHandler(dir)
+ result, err := rh.Watch(ctx)
+ if err != nil {
+ log.Errorf("failed to watch update result: %v", err)
+ return &proto.InstallerResultResponse{
+ Success: false,
+ ErrorMsg: err.Error(),
+ }, nil
+ }
+
+ return &proto.InstallerResultResponse{
+ Success: result.Success,
+ ErrorMsg: result.Error,
+ }, nil
+}
diff --git a/client/ssh/auth/auth.go b/client/ssh/auth/auth.go
new file mode 100644
index 000000000..488b6e12e
--- /dev/null
+++ b/client/ssh/auth/auth.go
@@ -0,0 +1,184 @@
+package auth
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+
+ log "github.com/sirupsen/logrus"
+
+ sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
+)
+
+const (
+ // DefaultUserIDClaim is the default JWT claim used to extract user IDs
+ DefaultUserIDClaim = "sub"
+ // Wildcard is a special user ID that matches all users
+ Wildcard = "*"
+)
+
+var (
+ ErrEmptyUserID = errors.New("JWT user ID is empty")
+ ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
+ ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
+ ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
+)
+
+// Authorizer handles SSH fine-grained access control authorization
+type Authorizer struct {
+ // UserIDClaim is the JWT claim to extract the user ID from
+ userIDClaim string
+
+ // authorizedUsers is a list of hashed user IDs authorized to access this peer
+ authorizedUsers []sshuserhash.UserIDHash
+
+ // machineUsers maps OS login usernames to lists of authorized user indexes
+ machineUsers map[string][]uint32
+
+ // mu protects the list of users
+ mu sync.RWMutex
+}
+
+// Config contains configuration for the SSH authorizer
+type Config struct {
+ // UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
+ UserIDClaim string
+
+ // AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
+ AuthorizedUsers []sshuserhash.UserIDHash
+
+ // MachineUsers maps OS login usernames to indexes in AuthorizedUsers
+ // If a user wants to login as a specific OS user, their index must be in the corresponding list
+ MachineUsers map[string][]uint32
+}
+
+// NewAuthorizer creates a new SSH authorizer with empty configuration
+func NewAuthorizer() *Authorizer {
+ a := &Authorizer{
+ userIDClaim: DefaultUserIDClaim,
+ machineUsers: make(map[string][]uint32),
+ }
+
+ return a
+}
+
+// Update updates the authorizer configuration with new values
+func (a *Authorizer) Update(config *Config) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if config == nil {
+ // Clear authorization
+ a.userIDClaim = DefaultUserIDClaim
+ a.authorizedUsers = []sshuserhash.UserIDHash{}
+ a.machineUsers = make(map[string][]uint32)
+ log.Info("SSH authorization cleared")
+ return
+ }
+
+ userIDClaim := config.UserIDClaim
+ if userIDClaim == "" {
+ userIDClaim = DefaultUserIDClaim
+ }
+ a.userIDClaim = userIDClaim
+
+ // Store authorized users list
+ a.authorizedUsers = config.AuthorizedUsers
+
+ // Store machine users mapping
+ machineUsers := make(map[string][]uint32)
+ for osUser, indexes := range config.MachineUsers {
+ if len(indexes) > 0 {
+ machineUsers[osUser] = indexes
+ }
+ }
+ a.machineUsers = machineUsers
+
+ log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
+ len(config.AuthorizedUsers), len(machineUsers))
+}
+
+// Authorize validates if a user is authorized to login as the specified OS user
+// Returns nil if authorized, or an error describing why authorization failed
+func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
+ if jwtUserID == "" {
+ log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
+ return ErrEmptyUserID
+ }
+
+ // Hash the JWT user ID for comparison
+ hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
+ if err != nil {
+ log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
+ return fmt.Errorf("failed to hash user ID: %w", err)
+ }
+
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ // Find the index of this user in the authorized list
+ userIndex, found := a.findUserIndex(hashedUserID)
+ if !found {
+ log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
+ return ErrUserNotAuthorized
+ }
+
+ return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
+}
+
+// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
+// Checks wildcard mapping first, then specific OS user mappings
+func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
+ // If wildcard exists and user's index is in the wildcard list, allow access to any OS user
+ if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
+ if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
+ log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
+ return nil
+ }
+ }
+
+ // Check for specific OS username mapping
+ allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
+ if !hasMachineUserMapping {
+ // No mapping for this OS user - deny by default (fail closed)
+ log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
+ return ErrNoMachineUserMapping
+ }
+
+ // Check if user's index is in the allowed indexes for this specific OS user
+ if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
+ log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
+ return ErrUserNotMappedToOSUser
+ }
+
+ log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
+ return nil
+}
+
+// GetUserIDClaim returns the JWT claim name used to extract user IDs
+func (a *Authorizer) GetUserIDClaim() string {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.userIDClaim
+}
+
+// findUserIndex finds the index of a hashed user ID in the authorized users list
+// Returns the index and true if found, 0 and false if not found
+func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
+ for i, id := range a.authorizedUsers {
+ if id == hashedUserID {
+ return i, true
+ }
+ }
+ return 0, false
+}
+
+// isIndexInList checks if an index exists in a list of indexes
+func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
+ for _, idx := range indexes {
+ if idx == index {
+ return true
+ }
+ }
+ return false
+}
diff --git a/client/ssh/auth/auth_test.go b/client/ssh/auth/auth_test.go
new file mode 100644
index 000000000..2b3b5a414
--- /dev/null
+++ b/client/ssh/auth/auth_test.go
@@ -0,0 +1,612 @@
+package auth
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/netbirdio/netbird/shared/sshauth"
+)
+
+func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up authorized users list with one user
+ authorizedUserHash, err := sshauth.HashUserID("authorized-user")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
+ MachineUsers: map[string][]uint32{},
+ }
+ authorizer.Update(config)
+
+ // Try to authorize a different user
+ err = authorizer.Authorize("unauthorized-user", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotAuthorized)
+}
+
+func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
+ MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
+ }
+ authorizer.Update(config)
+
+ // All attempts should fail when no machine user mappings exist (fail closed)
+ err = authorizer.Authorize("user1", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ err = authorizer.Authorize("user2", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ err = authorizer.Authorize("user1", "postgres")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+}
+
+func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+ user3Hash, err := sshauth.HashUserID("user3")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0, 1}, // user1 and user2 can access root
+ "postgres": {1, 2}, // user2 and user3 can access postgres
+ "admin": {0}, // only user1 can access admin
+ },
+ }
+ authorizer.Update(config)
+
+ // user1 (index 0) should access root and admin
+ err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ err = authorizer.Authorize("user1", "admin")
+ assert.NoError(t, err)
+
+ // user2 (index 1) should access root and postgres
+ err = authorizer.Authorize("user2", "root")
+ assert.NoError(t, err)
+
+ err = authorizer.Authorize("user2", "postgres")
+ assert.NoError(t, err)
+
+ // user3 (index 2) should access postgres
+ err = authorizer.Authorize("user3", "postgres")
+ assert.NoError(t, err)
+}
+
+func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up authorized users list
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+ user3Hash, err := sshauth.HashUserID("user3")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0, 1}, // user1 and user2 can access root
+ "postgres": {1, 2}, // user2 and user3 can access postgres
+ "admin": {0}, // only user1 can access admin
+ },
+ }
+ authorizer.Update(config)
+
+ // user1 (index 0) should NOT access postgres
+ err = authorizer.Authorize("user1", "postgres")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+
+ // user2 (index 1) should NOT access admin
+ err = authorizer.Authorize("user2", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+
+ // user3 (index 2) should NOT access root
+ err = authorizer.Authorize("user3", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+
+ // user3 (index 2) should NOT access admin
+ err = authorizer.Authorize("user3", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+}
+
+func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up authorized users list
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0}, // only root is mapped
+ },
+ }
+ authorizer.Update(config)
+
+ // user1 should NOT access an unmapped OS user (fail closed)
+ err = authorizer.Authorize("user1", "postgres")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+}
+
+func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up authorized users list
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{},
+ }
+ authorizer.Update(config)
+
+ // Empty user ID should fail
+ err = authorizer.Authorize("", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrEmptyUserID)
+}
+
+func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up multiple authorized users
+ userHashes := make([]sshauth.UserIDHash, 10)
+ for i := 0; i < 10; i++ {
+ hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
+ require.NoError(t, err)
+ userHashes[i] = hash
+ }
+
+ // Create machine user mapping for all users
+ rootIndexes := make([]uint32, 10)
+ for i := 0; i < 10; i++ {
+ rootIndexes[i] = uint32(i)
+ }
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: userHashes,
+ MachineUsers: map[string][]uint32{
+ "root": rootIndexes,
+ },
+ }
+ authorizer.Update(config)
+
+ // All users should be authorized for root
+ for i := 0; i < 10; i++ {
+ err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
+ assert.NoError(t, err, "user%d should be authorized", i)
+ }
+
+ // User not in list should fail
+ err := authorizer.Authorize("unknown-user", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotAuthorized)
+}
+
+func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up initial configuration
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{"root": {0}},
+ }
+ authorizer.Update(config)
+
+ // user1 should be authorized
+ err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ // Clear configuration
+ authorizer.Update(nil)
+
+ // user1 should no longer be authorized
+ err = authorizer.Authorize("user1", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotAuthorized)
+}
+
+func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ // Machine users with empty index lists should be filtered out
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0},
+ "postgres": {}, // empty list - should be filtered out
+ "admin": nil, // nil list - should be filtered out
+ },
+ }
+ authorizer.Update(config)
+
+ // root should work
+ err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ // postgres should fail (no mapping)
+ err = authorizer.Authorize("user1", "postgres")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ // admin should fail (no mapping)
+ err = authorizer.Authorize("user1", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+}
+
+func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up with custom user ID claim
+ user1Hash, err := sshauth.HashUserID("user@example.com")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: "email",
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0},
+ },
+ }
+ authorizer.Update(config)
+
+ // Verify the custom claim is set
+ assert.Equal(t, "email", authorizer.GetUserIDClaim())
+
+ // Authorize with email as user ID
+ err = authorizer.Authorize("user@example.com", "root")
+ assert.NoError(t, err)
+}
+
+func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Verify default claim
+ assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
+ assert.Equal(t, "sub", authorizer.GetUserIDClaim())
+
+ // Set up with empty user ID claim (should use default)
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: "", // empty - should use default
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{},
+ }
+ authorizer.Update(config)
+
+ // Should fall back to default
+ assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
+}
+
+func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Create a large authorized users list
+ const numUsers = 1000
+ userHashes := make([]sshauth.UserIDHash, numUsers)
+ for i := 0; i < numUsers; i++ {
+ hash, err := sshauth.HashUserID("user" + string(rune(i)))
+ require.NoError(t, err)
+ userHashes[i] = hash
+ }
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: userHashes,
+ MachineUsers: map[string][]uint32{
+ "root": {0, 500, 999}, // first, middle, and last user
+ },
+ }
+ authorizer.Update(config)
+
+ // First user should have access
+ err := authorizer.Authorize("user"+string(rune(0)), "root")
+ assert.NoError(t, err)
+
+ // Middle user should have access
+ err = authorizer.Authorize("user"+string(rune(500)), "root")
+ assert.NoError(t, err)
+
+ // Last user should have access
+ err = authorizer.Authorize("user"+string(rune(999)), "root")
+ assert.NoError(t, err)
+
+ // User not in mapping should NOT have access
+ err = authorizer.Authorize("user"+string(rune(100)), "root")
+ assert.Error(t, err)
+}
+
+func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up authorized users
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0, 1},
+ },
+ }
+ authorizer.Update(config)
+
+ // Test concurrent authorization calls (should be safe to read concurrently)
+ const numGoroutines = 100
+ errChan := make(chan error, numGoroutines)
+
+ for i := 0; i < numGoroutines; i++ {
+ go func(idx int) {
+ user := "user1"
+ if idx%2 == 0 {
+ user = "user2"
+ }
+ err := authorizer.Authorize(user, "root")
+ errChan <- err
+ }(i)
+ }
+
+ // Wait for all goroutines to complete and collect errors
+ for i := 0; i < numGoroutines; i++ {
+ err := <-errChan
+ assert.NoError(t, err)
+ }
+}
+
+func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+ user3Hash, err := sshauth.HashUserID("user3")
+ require.NoError(t, err)
+
+ // Configure with wildcard - all authorized users can access any OS user
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
+ MachineUsers: map[string][]uint32{
+ "*": {0, 1, 2}, // wildcard with all user indexes
+ },
+ }
+ authorizer.Update(config)
+
+ // All authorized users should be able to access any OS user
+ err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ err = authorizer.Authorize("user2", "postgres")
+ assert.NoError(t, err)
+
+ err = authorizer.Authorize("user3", "admin")
+ assert.NoError(t, err)
+
+ err = authorizer.Authorize("user1", "ubuntu")
+ assert.NoError(t, err)
+
+ err = authorizer.Authorize("user2", "nginx")
+ assert.NoError(t, err)
+
+ err = authorizer.Authorize("user3", "docker")
+ assert.NoError(t, err)
+}
+
+func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ // Configure with wildcard
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{
+ "*": {0},
+ },
+ }
+ authorizer.Update(config)
+
+ // user1 should have access
+ err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ // Unauthorized user should still be denied even with wildcard
+ err = authorizer.Authorize("unauthorized-user", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotAuthorized)
+}
+
+func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+
+ // Configure with both wildcard and specific mappings
+ // Wildcard takes precedence for users in the wildcard index list
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
+ MachineUsers: map[string][]uint32{
+ "*": {0, 1}, // wildcard for both users
+ "root": {0}, // specific mapping that would normally restrict to user1 only
+ },
+ }
+ authorizer.Update(config)
+
+ // Both users should be able to access root via wildcard (takes precedence over specific mapping)
+ err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ err = authorizer.Authorize("user2", "root")
+ assert.NoError(t, err)
+
+ // Both users should be able to access any other OS user via wildcard
+ err = authorizer.Authorize("user1", "postgres")
+ assert.NoError(t, err)
+
+ err = authorizer.Authorize("user2", "admin")
+ assert.NoError(t, err)
+}
+
+func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+
+ // Configure WITHOUT wildcard - only specific mappings
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0}, // only user1
+ "postgres": {1}, // only user2
+ },
+ }
+ authorizer.Update(config)
+
+ // user1 can access root
+ err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ // user2 can access postgres
+ err = authorizer.Authorize("user2", "postgres")
+ assert.NoError(t, err)
+
+ // user1 cannot access postgres
+ err = authorizer.Authorize("user1", "postgres")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+
+ // user2 cannot access root
+ err = authorizer.Authorize("user2", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+
+ // Neither can access unmapped OS users
+ err = authorizer.Authorize("user1", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ err = authorizer.Authorize("user2", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+}
+
+func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
+ // This test covers the scenario where wildcard exists with limited indexes.
+ // Only users whose indexes are in the wildcard list can access any OS user via wildcard.
+ // Other users can only access OS users they are explicitly mapped to.
+ authorizer := NewAuthorizer()
+
+ // Create two authorized user hashes (simulating the base64-encoded hashes in the config)
+ wasmHash, err := sshauth.HashUserID("wasm")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+
+ // Configure with wildcard having only index 0, and specific mappings for other OS users
+ config := &Config{
+ UserIDClaim: "sub",
+ AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
+ MachineUsers: map[string][]uint32{
+ "*": {0}, // wildcard with only index 0 - only wasm has wildcard access
+ "alice": {1}, // specific mapping for user2
+ "bob": {1}, // specific mapping for user2
+ },
+ }
+ authorizer.Update(config)
+
+ // wasm (index 0) should access any OS user via wildcard
+ err = authorizer.Authorize("wasm", "root")
+ assert.NoError(t, err, "wasm should access root via wildcard")
+
+ err = authorizer.Authorize("wasm", "alice")
+ assert.NoError(t, err, "wasm should access alice via wildcard")
+
+ err = authorizer.Authorize("wasm", "bob")
+ assert.NoError(t, err, "wasm should access bob via wildcard")
+
+ err = authorizer.Authorize("wasm", "postgres")
+ assert.NoError(t, err, "wasm should access postgres via wildcard")
+
+ // user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
+ err = authorizer.Authorize("user2", "alice")
+ assert.NoError(t, err, "user2 should access alice via explicit mapping")
+
+ err = authorizer.Authorize("user2", "bob")
+ assert.NoError(t, err, "user2 should access bob via explicit mapping")
+
+ err = authorizer.Authorize("user2", "root")
+ assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ err = authorizer.Authorize("user2", "postgres")
+ assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ // Unauthorized user should still be denied
+ err = authorizer.Authorize("user3", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
+}
diff --git a/client/ssh/client.go b/client/ssh/client.go
deleted file mode 100644
index afba347f8..000000000
--- a/client/ssh/client.go
+++ /dev/null
@@ -1,118 +0,0 @@
-//go:build !js
-
-package ssh
-
-import (
- "fmt"
- "net"
- "os"
- "time"
-
- "golang.org/x/crypto/ssh"
- "golang.org/x/term"
-)
-
-// Client wraps crypto/ssh Client to simplify usage
-type Client struct {
- client *ssh.Client
-}
-
-// Close closes the wrapped SSH Client
-func (c *Client) Close() error {
- return c.client.Close()
-}
-
-// OpenTerminal starts an interactive terminal session with the remote SSH server
-func (c *Client) OpenTerminal() error {
- session, err := c.client.NewSession()
- if err != nil {
- return fmt.Errorf("failed to open new session: %v", err)
- }
- defer func() {
- err := session.Close()
- if err != nil {
- return
- }
- }()
-
- fd := int(os.Stdout.Fd())
- state, err := term.MakeRaw(fd)
- if err != nil {
- return fmt.Errorf("failed to run raw terminal: %s", err)
- }
- defer func() {
- err := term.Restore(fd, state)
- if err != nil {
- return
- }
- }()
-
- w, h, err := term.GetSize(fd)
- if err != nil {
- return fmt.Errorf("terminal get size: %s", err)
- }
-
- modes := ssh.TerminalModes{
- ssh.ECHO: 1,
- ssh.TTY_OP_ISPEED: 14400,
- ssh.TTY_OP_OSPEED: 14400,
- }
-
- terminal := os.Getenv("TERM")
- if terminal == "" {
- terminal = "xterm-256color"
- }
- if err := session.RequestPty(terminal, h, w, modes); err != nil {
- return fmt.Errorf("failed requesting pty session with xterm: %s", err)
- }
-
- session.Stdout = os.Stdout
- session.Stderr = os.Stderr
- session.Stdin = os.Stdin
-
- if err := session.Shell(); err != nil {
- return fmt.Errorf("failed to start login shell on the remote host: %s", err)
- }
-
- if err := session.Wait(); err != nil {
- if e, ok := err.(*ssh.ExitError); ok {
- if e.ExitStatus() == 130 {
- return nil
- }
- }
- return fmt.Errorf("failed running SSH session: %s", err)
- }
-
- return nil
-}
-
-// DialWithKey connects to the remote SSH server with a provided private key file (PEM).
-func DialWithKey(addr, user string, privateKey []byte) (*Client, error) {
-
- signer, err := ssh.ParsePrivateKey(privateKey)
- if err != nil {
- return nil, err
- }
-
- config := &ssh.ClientConfig{
- User: user,
- Timeout: 5 * time.Second,
- Auth: []ssh.AuthMethod{
- ssh.PublicKeys(signer),
- },
- HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
- }
-
- return Dial("tcp", addr, config)
-}
-
-// Dial connects to the remote SSH server.
-func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) {
- client, err := ssh.Dial(network, addr, config)
- if err != nil {
- return nil, err
- }
- return &Client{
- client: client,
- }, nil
-}
diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go
new file mode 100644
index 000000000..aab222093
--- /dev/null
+++ b/client/ssh/client/client.go
@@ -0,0 +1,710 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh/knownhosts"
+ "golang.org/x/term"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+
+ "github.com/netbirdio/netbird/client/internal/profilemanager"
+ "github.com/netbirdio/netbird/client/proto"
+ nbssh "github.com/netbirdio/netbird/client/ssh"
+ "github.com/netbirdio/netbird/client/ssh/detection"
+ "github.com/netbirdio/netbird/util"
+)
+
+const (
+ // DefaultDaemonAddr is the default address for the NetBird daemon
+ DefaultDaemonAddr = "unix:///var/run/netbird.sock"
+ // DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows
+ DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731"
+)
+
+// Client wraps crypto/ssh Client for simplified SSH operations
+type Client struct {
+ client *ssh.Client
+ terminalState *term.State
+ terminalFd int
+
+ windowsStdoutMode uint32 // nolint:unused
+ windowsStdinMode uint32 // nolint:unused
+}
+
+func (c *Client) Close() error {
+ return c.client.Close()
+}
+
+func (c *Client) OpenTerminal(ctx context.Context) error {
+ session, err := c.client.NewSession()
+ if err != nil {
+ return fmt.Errorf("new session: %w", err)
+ }
+ defer func() {
+ if err := session.Close(); err != nil {
+ log.Debugf("session close error: %v", err)
+ }
+ }()
+
+ if err := c.setupTerminalMode(ctx, session); err != nil {
+ return err
+ }
+
+ c.setupSessionIO(session)
+
+ if err := session.Shell(); err != nil {
+ return fmt.Errorf("start shell: %w", err)
+ }
+
+ return c.waitForSession(ctx, session)
+}
+
+// setupSessionIO connects session streams to local terminal
+func (c *Client) setupSessionIO(session *ssh.Session) {
+ session.Stdout = os.Stdout
+ session.Stderr = os.Stderr
+ session.Stdin = os.Stdin
+}
+
+// waitForSession waits for the session to complete with context cancellation
+func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error {
+ done := make(chan error, 1)
+ go func() {
+ done <- session.Wait()
+ }()
+
+ defer c.restoreTerminal()
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case err := <-done:
+ return c.handleSessionError(err)
+ }
+}
+
+// handleSessionError processes session termination errors
+func (c *Client) handleSessionError(err error) error {
+ if err == nil {
+ return nil
+ }
+
+ var e *ssh.ExitError
+ var em *ssh.ExitMissingError
+ if !errors.As(err, &e) && !errors.As(err, &em) {
+ return fmt.Errorf("session wait: %w", err)
+ }
+
+ return nil
+}
+
+// restoreTerminal restores the terminal to its original state
+func (c *Client) restoreTerminal() {
+ if c.terminalState != nil {
+ _ = term.Restore(c.terminalFd, c.terminalState)
+ c.terminalState = nil
+ c.terminalFd = 0
+ }
+
+ if err := c.restoreWindowsConsoleState(); err != nil {
+ log.Debugf("restore Windows console state: %v", err)
+ }
+}
+
+// ExecuteCommand executes a command on the remote host and returns the output
+func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) {
+ session, cleanup, err := c.createSession(ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer cleanup()
+
+ output, err := session.CombinedOutput(command)
+ if err != nil {
+ var e *ssh.ExitError
+ var em *ssh.ExitMissingError
+ if !errors.As(err, &e) && !errors.As(err, &em) {
+ return output, fmt.Errorf("execute command: %w", err)
+ }
+ }
+
+ return output, nil
+}
+
+// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal
+func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error {
+ session, cleanup, err := c.createSession(ctx)
+ if err != nil {
+ return fmt.Errorf("create session: %w", err)
+ }
+ defer cleanup()
+
+ c.setupSessionIO(session)
+
+ if err := session.Start(command); err != nil {
+ return fmt.Errorf("start command: %w", err)
+ }
+
+ done := make(chan error, 1)
+ go func() {
+ done <- session.Wait()
+ }()
+
+ select {
+ case <-ctx.Done():
+ _ = session.Signal(ssh.SIGTERM)
+ select {
+ case <-done:
+ return ctx.Err()
+ case <-time.After(100 * time.Millisecond):
+ return ctx.Err()
+ }
+ case err := <-done:
+ return c.handleCommandError(err)
+ }
+}
+
+// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions
+func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error {
+ session, cleanup, err := c.createSession(ctx)
+ if err != nil {
+ return fmt.Errorf("create session: %w", err)
+ }
+ defer cleanup()
+
+ if err := c.setupTerminalMode(ctx, session); err != nil {
+ return fmt.Errorf("setup terminal mode: %w", err)
+ }
+
+ c.setupSessionIO(session)
+
+ if err := session.Start(command); err != nil {
+ return fmt.Errorf("start command: %w", err)
+ }
+
+ defer c.restoreTerminal()
+
+ done := make(chan error, 1)
+ go func() {
+ done <- session.Wait()
+ }()
+
+ select {
+ case <-ctx.Done():
+ _ = session.Signal(ssh.SIGTERM)
+ select {
+ case <-done:
+ return ctx.Err()
+ case <-time.After(100 * time.Millisecond):
+ return ctx.Err()
+ }
+ case err := <-done:
+ return c.handleCommandError(err)
+ }
+}
+
+// handleCommandError processes command execution errors
+func (c *Client) handleCommandError(err error) error {
+ if err == nil {
+ return nil
+ }
+
+ var e *ssh.ExitError
+ var em *ssh.ExitMissingError
+ if errors.As(err, &e) || errors.As(err, &em) {
+ return err
+ }
+
+ return fmt.Errorf("execute command: %w", err)
+}
+
+// setupContextCancellation sets up context cancellation for a session
+func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() {
+ done := make(chan struct{})
+ go func() {
+ select {
+ case <-ctx.Done():
+ _ = session.Signal(ssh.SIGTERM)
+ _ = session.Close()
+ case <-done:
+ }
+ }()
+ return func() { close(done) }
+}
+
+// createSession creates a new SSH session with context cancellation setup
+func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) {
+ session, err := c.client.NewSession()
+ if err != nil {
+ return nil, nil, fmt.Errorf("new session: %w", err)
+ }
+
+ cancel := c.setupContextCancellation(ctx, session)
+ cleanup := func() {
+ cancel()
+ _ = session.Close()
+ }
+
+ return session, cleanup, nil
+}
+
+// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
+func getDefaultDaemonAddr() string {
+ if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
+ return addr
+ }
+ if runtime.GOOS == "windows" {
+ return DefaultDaemonAddrWindows
+ }
+ return DefaultDaemonAddr
+}
+
+// DialOptions contains options for SSH connections
+type DialOptions struct {
+ KnownHostsFile string
+ IdentityFile string
+ DaemonAddr string
+ SkipCachedToken bool
+ InsecureSkipVerify bool
+ NoBrowser bool
+}
+
+// Dial connects to the given ssh server with specified options
+func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
+ daemonAddr := opts.DaemonAddr
+ if daemonAddr == "" {
+ daemonAddr = getDefaultDaemonAddr()
+ }
+ opts.DaemonAddr = daemonAddr
+
+ hostKeyCallback, err := createHostKeyCallback(opts)
+ if err != nil {
+ return nil, fmt.Errorf("create host key callback: %w", err)
+ }
+
+ config := &ssh.ClientConfig{
+ User: user,
+ Timeout: 30 * time.Second,
+ HostKeyCallback: hostKeyCallback,
+ }
+
+ if opts.IdentityFile != "" {
+ authMethod, err := createSSHKeyAuth(opts.IdentityFile)
+ if err != nil {
+ return nil, fmt.Errorf("create SSH key auth: %w", err)
+ }
+ config.Auth = append(config.Auth, authMethod)
+ }
+
+ return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken, opts.NoBrowser)
+}
+
+// dialSSH establishes an SSH connection without JWT authentication
+func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
+ dialer := &net.Dialer{}
+ conn, err := dialer.DialContext(ctx, network, addr)
+ if err != nil {
+ return nil, fmt.Errorf("dial %s: %w", addr, err)
+ }
+
+ clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
+ if err != nil {
+ if closeErr := conn.Close(); closeErr != nil {
+ log.Debugf("connection close after handshake failure: %v", closeErr)
+ }
+ return nil, fmt.Errorf("ssh handshake: %w", err)
+ }
+
+ client := ssh.NewClient(clientConn, chans, reqs)
+ return &Client{
+ client: client,
+ }, nil
+}
+
+// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
+func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache, noBrowser bool) (*Client, error) {
+ host, portStr, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, fmt.Errorf("parse address %s: %w", addr, err)
+ }
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return nil, fmt.Errorf("parse port %s: %w", portStr, err)
+ }
+
+ detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout)
+ defer cancel()
+
+ dialer := &net.Dialer{}
+ serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port)
+ if err != nil {
+ return nil, fmt.Errorf("SSH server detection: %w", err)
+ }
+
+ if !serverType.RequiresJWT() {
+ return dialSSH(ctx, network, addr, config)
+ }
+
+ jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
+ defer cancel()
+
+ jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache, noBrowser)
+ if err != nil {
+ return nil, fmt.Errorf("request JWT token: %w", err)
+ }
+
+ configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
+ return dialSSH(ctx, network, addr, configWithJWT)
+}
+
+// requestJWTToken requests a JWT token from the NetBird daemon
+func requestJWTToken(ctx context.Context, daemonAddr string, skipCache, noBrowser bool) (string, error) {
+ hint := profilemanager.GetLoginHint()
+
+ conn, err := connectToDaemon(daemonAddr)
+ if err != nil {
+ return "", fmt.Errorf("connect to daemon: %w", err)
+ }
+ defer conn.Close()
+
+ client := proto.NewDaemonServiceClient(conn)
+
+ var browserOpener func(string) error
+ if !noBrowser {
+ browserOpener = util.OpenBrowser
+ }
+
+ return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint, browserOpener)
+}
+
+// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
+func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
+ conn, err := connectToDaemon(daemonAddr)
+ if err != nil {
+ return err
+ }
+ defer func() {
+ if err := conn.Close(); err != nil {
+ log.Debugf("daemon connection close error: %v", err)
+ }
+ }()
+
+ client := proto.NewDaemonServiceClient(conn)
+ verifier := nbssh.NewDaemonHostKeyVerifier(client)
+ callback := nbssh.CreateHostKeyCallback(verifier)
+ return callback(hostname, remote, key)
+}
+
+func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
+ addr := strings.TrimPrefix(daemonAddr, "tcp://")
+
+ conn, err := grpc.NewClient(
+ addr,
+ grpc.WithTransportCredentials(insecure.NewCredentials()),
+ )
+ if err != nil {
+ log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
+ return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
+ }
+
+ return conn, nil
+}
+
+// getKnownHostsFiles returns paths to known_hosts files in order of preference
+func getKnownHostsFiles() []string {
+ var files []string
+
+ // User's known_hosts file (highest priority)
+ if homeDir, err := os.UserHomeDir(); err == nil {
+ userKnownHosts := filepath.Join(homeDir, ".ssh", "known_hosts")
+ files = append(files, userKnownHosts)
+ }
+
+ // NetBird managed known_hosts files
+ if runtime.GOOS == "windows" {
+ programData := os.Getenv("PROGRAMDATA")
+ if programData == "" {
+ programData = `C:\ProgramData`
+ }
+ netbirdKnownHosts := filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird")
+ files = append(files, netbirdKnownHosts)
+ } else {
+ files = append(files, "/etc/ssh/ssh_known_hosts.d/99-netbird")
+ files = append(files, "/etc/ssh/ssh_known_hosts")
+ }
+
+ return files
+}
+
+// createHostKeyCallback creates a host key verification callback
+func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
+ if opts.InsecureSkipVerify {
+ return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
+ }
+
+ return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
+ if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
+ return nil
+ }
+ return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile)
+ }, nil
+}
+
+func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
+ if daemonAddr == "" {
+ return fmt.Errorf("no daemon address")
+ }
+ return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr)
+}
+
+func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error {
+ knownHostsFiles := getKnownHostsFilesList(knownHostsFile)
+ hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles)
+
+ for _, callback := range hostKeyCallbacks {
+ if err := callback(hostname, remote, key); err == nil {
+ return nil
+ }
+ }
+ return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname)
+}
+
+func getKnownHostsFilesList(knownHostsFile string) []string {
+ if knownHostsFile != "" {
+ return []string{knownHostsFile}
+ }
+ return getKnownHostsFiles()
+}
+
+func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback {
+ var hostKeyCallbacks []ssh.HostKeyCallback
+ for _, file := range knownHostsFiles {
+ if callback, err := knownhosts.New(file); err == nil {
+ hostKeyCallbacks = append(hostKeyCallbacks, callback)
+ }
+ }
+ return hostKeyCallbacks
+}
+
+// createSSHKeyAuth creates SSH key authentication from a private key file
+func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) {
+ keyData, err := os.ReadFile(keyFile)
+ if err != nil {
+ return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err)
+ }
+
+ signer, err := ssh.ParsePrivateKey(keyData)
+ if err != nil {
+ return nil, fmt.Errorf("parse SSH private key: %w", err)
+ }
+
+ return ssh.PublicKeys(signer), nil
+}
+
+// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr
+func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error {
+ localListener, err := net.Listen("tcp", localAddr)
+ if err != nil {
+ return fmt.Errorf("listen on %s: %w", localAddr, err)
+ }
+
+ go func() {
+ defer func() {
+ if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
+ log.Debugf("local listener close error: %v", err)
+ }
+ }()
+ for {
+ localConn, err := localListener.Accept()
+ if err != nil {
+ if ctx.Err() != nil {
+ return
+ }
+ continue
+ }
+
+ go c.handleLocalForward(localConn, remoteAddr)
+ }
+ }()
+
+ <-ctx.Done()
+ if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
+ log.Debugf("local listener close error: %v", err)
+ }
+ return ctx.Err()
+}
+
+// handleLocalForward handles a single local port forwarding connection
+func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
+ defer func() {
+ if err := localConn.Close(); err != nil {
+ log.Debugf("local connection close error: %v", err)
+ }
+ }()
+
+ channel, err := c.client.Dial("tcp", remoteAddr)
+ if err != nil {
+ if strings.Contains(err.Error(), "administratively prohibited") {
+ _, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
+ } else {
+ log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
+ }
+ return
+ }
+ defer func() {
+ if err := channel.Close(); err != nil {
+ log.Debugf("remote channel close error: %v", err)
+ }
+ }()
+
+ go func() {
+ if _, err := io.Copy(channel, localConn); err != nil {
+ log.Debugf("local forward copy error (local->remote): %v", err)
+ }
+ }()
+
+ if _, err := io.Copy(localConn, channel); err != nil {
+ log.Debugf("local forward copy error (remote->local): %v", err)
+ }
+}
+
+// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
+func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error {
+ host, port, err := c.parseRemoteAddress(remoteAddr)
+ if err != nil {
+ return fmt.Errorf("parse remote address: %w", err)
+ }
+
+ req := c.buildTCPIPForwardRequest(host, port)
+ if err := c.sendTCPIPForwardRequest(req); err != nil {
+ return fmt.Errorf("setup remote forward: %w", err)
+ }
+
+ go c.handleRemoteForwardChannels(ctx, localAddr)
+
+ <-ctx.Done()
+
+ if err := c.cancelTCPIPForwardRequest(req); err != nil {
+ return fmt.Errorf("cancel tcpip-forward: %w", err)
+ }
+ return ctx.Err()
+}
+
+// parseRemoteAddress parses host and port from remote address string
+func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) {
+ host, portStr, err := net.SplitHostPort(remoteAddr)
+ if err != nil {
+ return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err)
+ }
+
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err)
+ }
+
+ return host, uint32(port), nil
+}
+
+// buildTCPIPForwardRequest creates a tcpip-forward request message
+func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg {
+ return tcpipForwardMsg{
+ Host: host,
+ Port: port,
+ }
+}
+
+// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding
+func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
+ ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req))
+ if err != nil {
+ return fmt.Errorf("send tcpip-forward request: %w", err)
+ }
+ if !ok {
+ return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
+ }
+ return nil
+}
+
+// cancelTCPIPForwardRequest cancels the tcpip-forward request
+func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error {
+ _, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req))
+ if err != nil {
+ return fmt.Errorf("send cancel-tcpip-forward request: %w", err)
+ }
+ return nil
+}
+
+// handleRemoteForwardChannels handles incoming forwarded-tcpip channels
+func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) {
+ // Get the channel once - subsequent calls return nil!
+ channelRequests := c.client.HandleChannelOpen("forwarded-tcpip")
+ if channelRequests == nil {
+ log.Debugf("forwarded-tcpip channel type already being handled")
+ return
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case newChan := <-channelRequests:
+ if newChan != nil {
+ go c.handleRemoteForwardChannel(newChan, localAddr)
+ }
+ }
+ }
+}
+
+// handleRemoteForwardChannel handles a single forwarded-tcpip channel
+func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
+ channel, reqs, err := newChan.Accept()
+ if err != nil {
+ return
+ }
+ defer func() {
+ if err := channel.Close(); err != nil {
+ log.Debugf("remote channel close error: %v", err)
+ }
+ }()
+
+ go ssh.DiscardRequests(reqs)
+
+ localConn, err := net.Dial("tcp", localAddr)
+ if err != nil {
+ return
+ }
+ defer func() {
+ if err := localConn.Close(); err != nil {
+ log.Debugf("local connection close error: %v", err)
+ }
+ }()
+
+ go func() {
+ if _, err := io.Copy(localConn, channel); err != nil {
+ log.Debugf("remote forward copy error (remote->local): %v", err)
+ }
+ }()
+
+ if _, err := io.Copy(channel, localConn); err != nil {
+ log.Debugf("remote forward copy error (local->remote): %v", err)
+ }
+}
+
+// tcpipForwardMsg represents the structure for tcpip-forward requests
+type tcpipForwardMsg struct {
+ Host string
+ Port uint32
+}
diff --git a/client/ssh/client/client_test.go b/client/ssh/client/client_test.go
new file mode 100644
index 000000000..e38e02a86
--- /dev/null
+++ b/client/ssh/client/client_test.go
@@ -0,0 +1,512 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "os/user"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ cryptossh "golang.org/x/crypto/ssh"
+
+ "github.com/netbirdio/netbird/client/ssh"
+ sshserver "github.com/netbirdio/netbird/client/ssh/server"
+ "github.com/netbirdio/netbird/client/ssh/testutil"
+)
+
+// TestMain handles package-level setup and cleanup
+func TestMain(m *testing.M) {
+ // Guard against infinite recursion when test binary is called as "netbird ssh exec"
+ // This happens when running tests as non-privileged user with fallback
+ if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
+ // Just exit with error to break the recursion
+ fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
+ os.Exit(1)
+ }
+
+ // Run tests
+ code := m.Run()
+
+ // Cleanup any created test users
+ testutil.CleanupTestUsers()
+
+ os.Exit(code)
+}
+
+func TestSSHClient_DialWithKey(t *testing.T) {
+ // Generate host key for server
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ // Create and start server
+ serverConfig := &sshserver.Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := sshserver.New(serverConfig)
+ server.SetAllowRootLogin(true) // Allow root/admin login for tests
+
+ serverAddr := sshserver.StartTestServer(t, server)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ // Test Dial
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ currentUser := testutil.GetTestUsername(t)
+ client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
+ InsecureSkipVerify: true,
+ })
+ require.NoError(t, err)
+ defer func() {
+ err := client.Close()
+ assert.NoError(t, err)
+ }()
+
+ // Verify client is connected
+ assert.NotNil(t, client.client)
+}
+
+func TestSSHClient_CommandExecution(t *testing.T) {
+ if runtime.GOOS == "windows" && testutil.IsCI() {
+ t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
+ }
+
+ server, _, client := setupTestSSHServerAndClient(t)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+ defer func() {
+ err := client.Close()
+ assert.NoError(t, err)
+ }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ t.Run("ExecuteCommand captures output", func(t *testing.T) {
+ output, err := client.ExecuteCommand(ctx, "echo hello")
+ assert.NoError(t, err)
+ assert.Contains(t, string(output), "hello")
+ })
+
+ t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
+ err := client.ExecuteCommandWithIO(ctx, "echo world")
+ assert.NoError(t, err)
+ })
+
+ t.Run("commands with flags work", func(t *testing.T) {
+ output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
+ assert.NoError(t, err)
+ assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
+ })
+
+ t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
+ var testCmd string
+ if runtime.GOOS == "windows" {
+ testCmd = "echo hello | Select-String notfound"
+ } else {
+ testCmd = "echo 'hello' | grep 'notfound'"
+ }
+ _, err := client.ExecuteCommand(ctx, testCmd)
+ assert.NoError(t, err)
+ })
+}
+
+func TestSSHClient_ConnectionHandling(t *testing.T) {
+ server, serverAddr, _ := setupTestSSHServerAndClient(t)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ // Generate client key for multiple connections
+
+ const numClients = 3
+ clients := make([]*Client, numClients)
+
+ currentUser := testutil.GetTestUsername(t)
+ for i := 0; i < numClients; i++ {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
+ InsecureSkipVerify: true,
+ })
+ cancel()
+ require.NoError(t, err, "Client %d should connect successfully", i)
+ clients[i] = client
+ }
+
+ for i, client := range clients {
+ err := client.Close()
+ assert.NoError(t, err, "Client %d should close without error", i)
+ }
+}
+
+func TestSSHClient_ContextCancellation(t *testing.T) {
+ server, serverAddr, _ := setupTestSSHServerAndClient(t)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ t.Run("connection with short timeout", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+ defer cancel()
+
+ currentUser := testutil.GetTestUsername(t)
+ _, err := Dial(ctx, serverAddr, currentUser, DialOptions{
+ InsecureSkipVerify: true,
+ })
+ if err != nil {
+ // Check for actual timeout-related errors rather than string matching
+ assert.True(t,
+ errors.Is(err, context.DeadlineExceeded) ||
+ errors.Is(err, context.Canceled) ||
+ strings.Contains(err.Error(), "timeout"),
+ "Expected timeout-related error, got: %v", err)
+ }
+ })
+
+ t.Run("command execution cancellation", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ currentUser := testutil.GetTestUsername(t)
+ client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
+ InsecureSkipVerify: true,
+ })
+ require.NoError(t, err)
+ defer func() {
+ if err := client.Close(); err != nil {
+ t.Logf("client close error: %v", err)
+ }
+ }()
+
+ cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cmdCancel()
+
+ err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
+ if err != nil {
+ var exitMissingErr *cryptossh.ExitMissingError
+ isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
+ errors.Is(err, context.Canceled) ||
+ errors.As(err, &exitMissingErr)
+ assert.True(t, isValidCancellation, "Should handle command cancellation properly")
+ }
+ })
+}
+
+func TestSSHClient_NoAuthMode(t *testing.T) {
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ serverConfig := &sshserver.Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := sshserver.New(serverConfig)
+ server.SetAllowRootLogin(true) // Allow root/admin login for tests
+
+ serverAddr := sshserver.StartTestServer(t, server)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ currentUser := testutil.GetTestUsername(t)
+
+ t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
+ client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
+ InsecureSkipVerify: true,
+ })
+ assert.NoError(t, err)
+ if client != nil {
+ require.NoError(t, client.Close(), "Client should close without error")
+ }
+ })
+}
+
+func TestSSHClient_TerminalState(t *testing.T) {
+ server, _, client := setupTestSSHServerAndClient(t)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+ defer func() {
+ err := client.Close()
+ assert.NoError(t, err)
+ }()
+
+ assert.Nil(t, client.terminalState)
+ assert.Equal(t, 0, client.terminalFd)
+
+ client.restoreTerminal()
+ assert.Nil(t, client.terminalState)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ defer cancel()
+
+ err := client.OpenTerminal(ctx)
+ // In test environment without a real terminal, this may complete quickly or timeout
+ // Both behaviors are acceptable for testing terminal state management
+ if err != nil {
+ if runtime.GOOS == "windows" {
+ assert.True(t,
+ strings.Contains(err.Error(), "context deadline exceeded") ||
+ strings.Contains(err.Error(), "console"),
+ "Should timeout or have console error on Windows")
+ } else {
+ // On Unix systems in test environment, we may get various errors
+ // including timeouts or terminal-related errors
+ assert.True(t,
+ strings.Contains(err.Error(), "context deadline exceeded") ||
+ strings.Contains(err.Error(), "terminal") ||
+ strings.Contains(err.Error(), "pty"),
+ "Expected timeout or terminal-related error, got: %v", err)
+ }
+ }
+}
+
+func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Client) {
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ serverConfig := &sshserver.Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := sshserver.New(serverConfig)
+ server.SetAllowRootLogin(true) // Allow root/admin login for tests
+
+ serverAddr := sshserver.StartTestServer(t, server)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ currentUser := testutil.GetTestUsername(t)
+ client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
+ InsecureSkipVerify: true,
+ })
+ require.NoError(t, err)
+
+ return server, serverAddr, client
+}
+
+func TestSSHClient_PortForwarding(t *testing.T) {
+ server, _, client := setupTestSSHServerAndClient(t)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+ defer func() {
+ err := client.Close()
+ assert.NoError(t, err)
+ }()
+
+ t.Run("local forwarding times out gracefully", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+
+ err := client.LocalPortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
+ assert.Error(t, err)
+ assert.True(t,
+ errors.Is(err, context.DeadlineExceeded) ||
+ errors.Is(err, context.Canceled) ||
+ strings.Contains(err.Error(), "connection"),
+ "Expected context or connection error")
+ })
+
+ t.Run("remote forwarding denied", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+ defer cancel()
+
+ err := client.RemotePortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
+ assert.Error(t, err)
+ assert.True(t,
+ strings.Contains(err.Error(), "denied") ||
+ strings.Contains(err.Error(), "disabled"),
+ "Should be denied by default")
+ })
+
+ t.Run("invalid addresses fail", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+ defer cancel()
+
+ err := client.LocalPortForward(ctx, "invalid:address", "127.0.0.1:8080")
+ assert.Error(t, err)
+
+ err = client.LocalPortForward(ctx, "127.0.0.1:0", "invalid:address")
+ assert.Error(t, err)
+ })
+}
+
+func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping data transfer test in short mode")
+ }
+
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ serverConfig := &sshserver.Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := sshserver.New(serverConfig)
+ server.SetAllowLocalPortForwarding(true)
+ server.SetAllowRootLogin(true) // Allow root/admin login for tests
+
+ serverAddr := sshserver.StartTestServer(t, server)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ // Port forwarding requires the actual current user, not test user
+ realUser, err := getRealCurrentUser()
+ require.NoError(t, err)
+
+ // Skip if running as system account that can't do port forwarding
+ if testutil.IsSystemAccount(realUser) {
+ t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
+ }
+
+ client, err := Dial(ctx, serverAddr, realUser, DialOptions{
+ InsecureSkipVerify: true, // Skip host key verification for test
+ })
+ require.NoError(t, err)
+ defer func() {
+ if err := client.Close(); err != nil {
+ t.Logf("client close error: %v", err)
+ }
+ }()
+
+ testServer, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ defer func() {
+ if err := testServer.Close(); err != nil {
+ t.Logf("test server close error: %v", err)
+ }
+ }()
+
+ testServerAddr := testServer.Addr().String()
+ expectedResponse := "Hello, World!"
+
+ go func() {
+ for {
+ conn, err := testServer.Accept()
+ if err != nil {
+ return
+ }
+ go func(c net.Conn) {
+ defer func() {
+ if err := c.Close(); err != nil {
+ t.Logf("connection close error: %v", err)
+ }
+ }()
+ buf := make([]byte, 1024)
+ if _, err := c.Read(buf); err != nil {
+ t.Logf("connection read error: %v", err)
+ return
+ }
+ if _, err := c.Write([]byte(expectedResponse)); err != nil {
+ t.Logf("connection write error: %v", err)
+ }
+ }(conn)
+ }
+ }()
+
+ localListener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ localAddr := localListener.Addr().String()
+ if err := localListener.Close(); err != nil {
+ t.Logf("local listener close error: %v", err)
+ }
+
+ ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ go func() {
+ err := client.LocalPortForward(ctx, localAddr, testServerAddr)
+ if err != nil && !errors.Is(err, context.Canceled) {
+ if isWindowsPrivilegeError(err) {
+ t.Logf("Port forward failed due to Windows privilege restrictions: %v", err)
+ } else {
+ t.Logf("Port forward error: %v", err)
+ }
+ }
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+
+ conn, err := net.DialTimeout("tcp", localAddr, 2*time.Second)
+ require.NoError(t, err)
+ defer func() {
+ if err := conn.Close(); err != nil {
+ t.Logf("connection close error: %v", err)
+ }
+ }()
+
+ _, err = conn.Write([]byte("test"))
+ require.NoError(t, err)
+
+ if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
+ t.Logf("set read deadline error: %v", err)
+ }
+ response := make([]byte, len(expectedResponse))
+ n, err := io.ReadFull(conn, response)
+ require.NoError(t, err)
+ assert.Equal(t, len(expectedResponse), n)
+ assert.Equal(t, expectedResponse, string(response))
+}
+
+// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
+func getRealCurrentUser() (string, error) {
+ if runtime.GOOS == "windows" {
+ if currentUser, err := user.Current(); err == nil {
+ return currentUser.Username, nil
+ }
+ }
+
+ if username := os.Getenv("USER"); username != "" {
+ return username, nil
+ }
+
+ if currentUser, err := user.Current(); err == nil {
+ return currentUser.Username, nil
+ }
+
+ return "", fmt.Errorf("unable to determine current user")
+}
+
+// isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions
+func isWindowsPrivilegeError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ errStr := strings.ToLower(err.Error())
+ return strings.Contains(errStr, "ntstatus=0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD
+ strings.Contains(errStr, "0xc0000041") || // STATUS_PRIVILEGE_NOT_HELD (LsaRegisterLogonProcess)
+ strings.Contains(errStr, "0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD (LsaLogonUser)
+ strings.Contains(errStr, "privilege") ||
+ strings.Contains(errStr, "access denied") ||
+ strings.Contains(errStr, "user authentication failed")
+}
diff --git a/client/ssh/client/terminal_unix.go b/client/ssh/client/terminal_unix.go
new file mode 100644
index 000000000..aaa3418f9
--- /dev/null
+++ b/client/ssh/client/terminal_unix.go
@@ -0,0 +1,127 @@
+//go:build !windows
+
+package client
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/signal"
+ "syscall"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/term"
+)
+
+func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
+ stdinFd := int(os.Stdin.Fd())
+
+ if !term.IsTerminal(stdinFd) {
+ return c.setupNonTerminalMode(ctx, session)
+ }
+
+ fd := int(os.Stdin.Fd())
+
+ state, err := term.MakeRaw(fd)
+ if err != nil {
+ return c.setupNonTerminalMode(ctx, session)
+ }
+
+ if err := c.setupTerminal(session, fd); err != nil {
+ if restoreErr := term.Restore(fd, state); restoreErr != nil {
+ log.Debugf("restore terminal state: %v", restoreErr)
+ }
+ return err
+ }
+
+ c.terminalState = state
+ c.terminalFd = fd
+
+ sigChan := make(chan os.Signal, 1)
+ signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
+
+ go func() {
+ defer signal.Stop(sigChan)
+ select {
+ case <-ctx.Done():
+ if err := term.Restore(fd, state); err != nil {
+ log.Debugf("restore terminal state: %v", err)
+ }
+ case sig := <-sigChan:
+ if err := term.Restore(fd, state); err != nil {
+ log.Debugf("restore terminal state: %v", err)
+ }
+ signal.Reset(sig)
+ s, ok := sig.(syscall.Signal)
+ if !ok {
+ log.Debugf("signal %v is not a syscall.Signal: %T", sig, sig)
+ return
+ }
+ if err := syscall.Kill(syscall.Getpid(), s); err != nil {
+ log.Debugf("kill process with signal %v: %v", s, err)
+ }
+ }
+ }()
+
+ return nil
+}
+
+func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
+ return nil
+}
+
+// restoreWindowsConsoleState is a no-op on Unix systems
+func (c *Client) restoreWindowsConsoleState() error {
+ return nil
+}
+
+func (c *Client) setupTerminal(session *ssh.Session, fd int) error {
+ w, h, err := term.GetSize(fd)
+ if err != nil {
+ return fmt.Errorf("get terminal size: %w", err)
+ }
+
+ modes := ssh.TerminalModes{
+ ssh.ECHO: 1,
+ ssh.TTY_OP_ISPEED: 14400,
+ ssh.TTY_OP_OSPEED: 14400,
+ // Ctrl+C
+ ssh.VINTR: 3,
+ // Ctrl+\
+ ssh.VQUIT: 28,
+ // Backspace
+ ssh.VERASE: 127,
+ // Ctrl+U
+ ssh.VKILL: 21,
+ // Ctrl+D
+ ssh.VEOF: 4,
+ ssh.VEOL: 0,
+ ssh.VEOL2: 0,
+ // Ctrl+Q
+ ssh.VSTART: 17,
+ // Ctrl+S
+ ssh.VSTOP: 19,
+ // Ctrl+Z
+ ssh.VSUSP: 26,
+ // Ctrl+O
+ ssh.VDISCARD: 15,
+ // Ctrl+R
+ ssh.VREPRINT: 18,
+ // Ctrl+W
+ ssh.VWERASE: 23,
+ // Ctrl+V
+ ssh.VLNEXT: 22,
+ }
+
+ terminal := os.Getenv("TERM")
+ if terminal == "" {
+ terminal = "xterm-256color"
+ }
+
+ if err := session.RequestPty(terminal, h, w, modes); err != nil {
+ return fmt.Errorf("request pty: %w", err)
+ }
+
+ return nil
+}
diff --git a/client/ssh/client/terminal_windows.go b/client/ssh/client/terminal_windows.go
new file mode 100644
index 000000000..462438317
--- /dev/null
+++ b/client/ssh/client/terminal_windows.go
@@ -0,0 +1,265 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "syscall"
+ "unsafe"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/crypto/ssh"
+)
+
+const (
+ enableProcessedInput = 0x0001
+ enableLineInput = 0x0002
+ enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT
+ enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode)
+ enableVirtualTerminalInput = 0x0200
+)
+
+var (
+ kernel32 = syscall.NewLazyDLL("kernel32.dll")
+ procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
+ procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
+ procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
+)
+
+// ConsoleUnavailableError indicates that Windows console handles are not available
+// (e.g., in CI environments where stdout/stdin are redirected)
+type ConsoleUnavailableError struct {
+ Operation string
+ Err error
+}
+
+func (e *ConsoleUnavailableError) Error() string {
+ return fmt.Sprintf("console unavailable for %s: %v", e.Operation, e.Err)
+}
+
+func (e *ConsoleUnavailableError) Unwrap() error {
+ return e.Err
+}
+
+type coord struct {
+ x, y int16
+}
+
+type smallRect struct {
+ left, top, right, bottom int16
+}
+
+type consoleScreenBufferInfo struct {
+ size coord
+ cursorPosition coord
+ attributes uint16
+ window smallRect
+ maximumWindowSize coord
+}
+
+func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) error {
+ if err := c.saveWindowsConsoleState(); err != nil {
+ var consoleErr *ConsoleUnavailableError
+ if errors.As(err, &consoleErr) {
+ log.Debugf("console unavailable, not requesting PTY: %v", err)
+ return nil
+ }
+ return fmt.Errorf("save console state: %w", err)
+ }
+
+ if err := c.enableWindowsVirtualTerminal(); err != nil {
+ var consoleErr *ConsoleUnavailableError
+ if errors.As(err, &consoleErr) {
+ log.Debugf("virtual terminal unavailable: %v", err)
+ } else {
+ return fmt.Errorf("failed to enable virtual terminal: %w", err)
+ }
+ }
+
+ w, h := c.getWindowsConsoleSize()
+
+ modes := ssh.TerminalModes{
+ ssh.ECHO: 1,
+ ssh.TTY_OP_ISPEED: 14400,
+ ssh.TTY_OP_OSPEED: 14400,
+ ssh.ICRNL: 1,
+ ssh.OPOST: 1,
+ ssh.ONLCR: 1,
+ ssh.ISIG: 1,
+ ssh.ICANON: 1,
+ ssh.VINTR: 3, // Ctrl+C
+ ssh.VQUIT: 28, // Ctrl+\
+ ssh.VERASE: 127, // Backspace
+ ssh.VKILL: 21, // Ctrl+U
+ ssh.VEOF: 4, // Ctrl+D
+ ssh.VEOL: 0,
+ ssh.VEOL2: 0,
+ ssh.VSTART: 17, // Ctrl+Q
+ ssh.VSTOP: 19, // Ctrl+S
+ ssh.VSUSP: 26, // Ctrl+Z
+ ssh.VDISCARD: 15, // Ctrl+O
+ ssh.VWERASE: 23, // Ctrl+W
+ ssh.VLNEXT: 22, // Ctrl+V
+ ssh.VREPRINT: 18, // Ctrl+R
+ }
+
+ if err := session.RequestPty("xterm-256color", h, w, modes); err != nil {
+ if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil {
+ log.Debugf("restore Windows console state: %v", restoreErr)
+ }
+ return fmt.Errorf("request pty: %w", err)
+ }
+
+ return nil
+}
+
+func (c *Client) saveWindowsConsoleState() error {
+ defer func() {
+ if r := recover(); r != nil {
+ log.Debugf("panic in saveWindowsConsoleState: %v", r)
+ }
+ }()
+
+ stdout := syscall.Handle(os.Stdout.Fd())
+ stdin := syscall.Handle(os.Stdin.Fd())
+
+ var stdoutMode, stdinMode uint32
+
+ ret, _, err := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&stdoutMode)))
+ if ret == 0 {
+ log.Debugf("failed to get stdout console mode: %v", err)
+ return &ConsoleUnavailableError{
+ Operation: "get stdout console mode",
+ Err: err,
+ }
+ }
+
+ ret, _, err = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&stdinMode)))
+ if ret == 0 {
+ log.Debugf("failed to get stdin console mode: %v", err)
+ return &ConsoleUnavailableError{
+ Operation: "get stdin console mode",
+ Err: err,
+ }
+ }
+
+ c.terminalFd = 1
+ c.windowsStdoutMode = stdoutMode
+ c.windowsStdinMode = stdinMode
+
+ log.Debugf("saved Windows console state - stdout: 0x%04x, stdin: 0x%04x", stdoutMode, stdinMode)
+ return nil
+}
+
+func (c *Client) enableWindowsVirtualTerminal() (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ err = fmt.Errorf("panic in enableWindowsVirtualTerminal: %v", r)
+ }
+ }()
+
+ stdout := syscall.Handle(os.Stdout.Fd())
+ stdin := syscall.Handle(os.Stdin.Fd())
+ var mode uint32
+
+ ret, _, winErr := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&mode)))
+ if ret == 0 {
+ return &ConsoleUnavailableError{
+ Operation: "get stdout console mode for VT",
+ Err: winErr,
+ }
+ }
+
+ mode |= enableVirtualTerminalProcessing
+ ret, _, winErr = procSetConsoleMode.Call(uintptr(stdout), uintptr(mode))
+ if ret == 0 {
+ return &ConsoleUnavailableError{
+ Operation: "enable virtual terminal processing",
+ Err: winErr,
+ }
+ }
+
+ ret, _, winErr = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&mode)))
+ if ret == 0 {
+ return &ConsoleUnavailableError{
+ Operation: "get stdin console mode for VT",
+ Err: winErr,
+ }
+ }
+
+ mode &= ^uint32(enableLineInput | enableEchoInput | enableProcessedInput)
+ mode |= enableVirtualTerminalInput
+ ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(mode))
+ if ret == 0 {
+ return &ConsoleUnavailableError{
+ Operation: "set stdin raw mode",
+ Err: winErr,
+ }
+ }
+
+ log.Debugf("enabled Windows virtual terminal processing")
+ return nil
+}
+
+func (c *Client) getWindowsConsoleSize() (int, int) {
+ defer func() {
+ if r := recover(); r != nil {
+ log.Debugf("panic in getWindowsConsoleSize: %v", r)
+ }
+ }()
+
+ stdout := syscall.Handle(os.Stdout.Fd())
+ var csbi consoleScreenBufferInfo
+
+ ret, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(stdout), uintptr(unsafe.Pointer(&csbi)))
+ if ret == 0 {
+ log.Debugf("failed to get console buffer info, using defaults: %v", err)
+ return 80, 24
+ }
+
+ width := int(csbi.window.right - csbi.window.left + 1)
+ height := int(csbi.window.bottom - csbi.window.top + 1)
+
+ log.Debugf("Windows console size: %dx%d", width, height)
+ return width, height
+}
+
+func (c *Client) restoreWindowsConsoleState() error {
+ var err error
+ defer func() {
+ if r := recover(); r != nil {
+ err = fmt.Errorf("panic in restoreWindowsConsoleState: %v", r)
+ }
+ }()
+
+ if c.terminalFd != 1 {
+ return nil
+ }
+
+ stdout := syscall.Handle(os.Stdout.Fd())
+ stdin := syscall.Handle(os.Stdin.Fd())
+
+ ret, _, winErr := procSetConsoleMode.Call(uintptr(stdout), uintptr(c.windowsStdoutMode))
+ if ret == 0 {
+ log.Debugf("failed to restore stdout console mode: %v", winErr)
+ if err == nil {
+ err = fmt.Errorf("restore stdout console mode: %w", winErr)
+ }
+ }
+
+ ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(c.windowsStdinMode))
+ if ret == 0 {
+ log.Debugf("failed to restore stdin console mode: %v", winErr)
+ if err == nil {
+ err = fmt.Errorf("restore stdin console mode: %w", winErr)
+ }
+ }
+
+ c.terminalFd = 0
+ c.windowsStdoutMode = 0
+ c.windowsStdinMode = 0
+
+ log.Debugf("restored Windows console state")
+ return err
+}
diff --git a/client/ssh/common.go b/client/ssh/common.go
new file mode 100644
index 000000000..6574437b5
--- /dev/null
+++ b/client/ssh/common.go
@@ -0,0 +1,195 @@
+package ssh
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/crypto/ssh"
+
+ "github.com/netbirdio/netbird/client/proto"
+)
+
+const (
+ NetBirdSSHConfigFile = "99-netbird.conf"
+
+ UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
+ WindowsSSHConfigDir = "ssh/ssh_config.d"
+)
+
+var (
+ // ErrPeerNotFound indicates the peer was not found in the network
+ ErrPeerNotFound = errors.New("peer not found in network")
+ // ErrNoStoredKey indicates the peer has no stored SSH host key
+ ErrNoStoredKey = errors.New("peer has no stored SSH host key")
+)
+
+// HostKeyVerifier provides SSH host key verification
+type HostKeyVerifier interface {
+ VerifySSHHostKey(peerAddress string, key []byte) error
+}
+
+// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
+type DaemonHostKeyVerifier struct {
+ client proto.DaemonServiceClient
+}
+
+// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
+func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
+ return &DaemonHostKeyVerifier{
+ client: client,
+ }
+}
+
+// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
+func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
+ PeerAddress: peerAddress,
+ })
+ if err != nil {
+ return err
+ }
+
+ if !response.GetFound() {
+ return ErrPeerNotFound
+ }
+
+ storedKeyData := response.GetSshHostKey()
+
+ return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
+}
+
+// printAuthInstructions prints authentication instructions to stderr
+func printAuthInstructions(stderr io.Writer, authResponse *proto.RequestJWTAuthResponse, browserWillOpen bool) {
+ _, _ = fmt.Fprintln(stderr, "SSH authentication required.")
+
+ if browserWillOpen {
+ _, _ = fmt.Fprintln(stderr, "Please do the SSO login in your browser.")
+ _, _ = fmt.Fprintln(stderr, "If your browser didn't open automatically, use this URL to log in:")
+ _, _ = fmt.Fprintln(stderr)
+ }
+
+ _, _ = fmt.Fprintf(stderr, "%s\n", authResponse.VerificationURIComplete)
+
+ if authResponse.UserCode != "" {
+ _, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
+ }
+
+ if browserWillOpen {
+ _, _ = fmt.Fprintln(stderr)
+ }
+
+ _, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
+}
+
+// RequestJWTToken requests or retrieves a JWT token for SSH authentication
+func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string, openBrowser func(string) error) (string, error) {
+ req := &proto.RequestJWTAuthRequest{}
+ if hint != "" {
+ req.Hint = &hint
+ }
+ authResponse, err := client.RequestJWTAuth(ctx, req)
+ if err != nil {
+ return "", fmt.Errorf("request JWT auth: %w", err)
+ }
+
+ if useCache && authResponse.CachedToken != "" {
+ log.Debug("Using cached authentication token")
+ return authResponse.CachedToken, nil
+ }
+
+ if stderr != nil {
+ printAuthInstructions(stderr, authResponse, openBrowser != nil)
+ }
+
+ if openBrowser != nil {
+ if err := openBrowser(authResponse.VerificationURIComplete); err != nil {
+ log.Debugf("open browser: %v", err)
+ }
+ }
+
+ tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
+ DeviceCode: authResponse.DeviceCode,
+ UserCode: authResponse.UserCode,
+ })
+ if err != nil {
+ return "", fmt.Errorf("wait for JWT token: %w", err)
+ }
+
+ if stdout != nil {
+ _, _ = fmt.Fprintln(stdout, "Authentication successful!")
+ }
+ return tokenResponse.Token, nil
+}
+
+// VerifyHostKey verifies an SSH host key against stored peer key data.
+// Returns nil only if the presented key matches the stored key.
+// Returns ErrNoStoredKey if storedKeyData is empty.
+// Returns an error if the keys don't match or if parsing fails.
+func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
+ if len(storedKeyData) == 0 {
+ return ErrNoStoredKey
+ }
+
+ storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
+ if err != nil {
+ return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
+ }
+
+ if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
+ return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
+ }
+
+ return nil
+}
+
+// AddJWTAuth prepends JWT password authentication to existing auth methods.
+// This ensures JWT auth is tried first while preserving any existing auth methods.
+func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
+ configWithJWT := *config
+ configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
+ return &configWithJWT
+}
+
+// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
+// It tries multiple addresses (hostname, IP) for the peer before failing.
+func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
+ return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
+ addresses := buildAddressList(hostname, remote)
+ presentedKey := key.Marshal()
+
+ for _, addr := range addresses {
+ if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
+ if errors.Is(err, ErrPeerNotFound) {
+ // Try other addresses for this peer
+ continue
+ }
+ return err
+ }
+ // Verified
+ return nil
+ }
+
+ return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
+ }
+}
+
+// buildAddressList creates a list of addresses to check for host key verification.
+// It includes the original hostname and extracts the host part from the remote address if different.
+func buildAddressList(hostname string, remote net.Addr) []string {
+ addresses := []string{hostname}
+ if host, _, err := net.SplitHostPort(remote.String()); err == nil {
+ if host != hostname {
+ addresses = append(addresses, host)
+ }
+ }
+ return addresses
+}
diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go
new file mode 100644
index 000000000..cc47fd2d2
--- /dev/null
+++ b/client/ssh/config/manager.go
@@ -0,0 +1,277 @@
+package config
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ nbssh "github.com/netbirdio/netbird/client/ssh"
+)
+
+const (
+ EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
+
+ EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
+
+ MaxPeersForSSHConfig = 200
+
+ fileWriteTimeout = 2 * time.Second
+)
+
+func isSSHConfigDisabled() bool {
+ value := os.Getenv(EnvDisableSSHConfig)
+ if value == "" {
+ return false
+ }
+
+ disabled, err := strconv.ParseBool(value)
+ if err != nil {
+ return true
+ }
+ return disabled
+}
+
+func isSSHConfigForced() bool {
+ value := os.Getenv(EnvForceSSHConfig)
+ if value == "" {
+ return false
+ }
+
+ forced, err := strconv.ParseBool(value)
+ if err != nil {
+ return true
+ }
+ return forced
+}
+
+// shouldGenerateSSHConfig checks if SSH config should be generated based on peer count
+func shouldGenerateSSHConfig(peerCount int) bool {
+ if isSSHConfigDisabled() {
+ return false
+ }
+
+ if isSSHConfigForced() {
+ return true
+ }
+
+ return peerCount <= MaxPeersForSSHConfig
+}
+
+// writeFileWithTimeout writes data to a file with a timeout
+func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error {
+ ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
+ defer cancel()
+
+ done := make(chan error, 1)
+ go func() {
+ done <- os.WriteFile(filename, data, perm)
+ }()
+
+ select {
+ case err := <-done:
+ return err
+ case <-ctx.Done():
+ return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
+ }
+}
+
+// Manager handles SSH client configuration for NetBird peers
+type Manager struct {
+ sshConfigDir string
+ sshConfigFile string
+}
+
+// PeerSSHInfo represents a peer's SSH configuration information
+type PeerSSHInfo struct {
+ Hostname string
+ IP string
+ FQDN string
+}
+
+// New creates a new SSH config manager
+func New() *Manager {
+ sshConfigDir := getSystemSSHConfigDir()
+ return &Manager{
+ sshConfigDir: sshConfigDir,
+ sshConfigFile: nbssh.NetBirdSSHConfigFile,
+ }
+}
+
+// getSystemSSHConfigDir returns platform-specific SSH configuration directory
+func getSystemSSHConfigDir() string {
+ if runtime.GOOS == "windows" {
+ return getWindowsSSHConfigDir()
+ }
+ return nbssh.UnixSSHConfigDir
+}
+
+func getWindowsSSHConfigDir() string {
+ programData := os.Getenv("PROGRAMDATA")
+ if programData == "" {
+ programData = `C:\ProgramData`
+ }
+ return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
+}
+
+// SetupSSHClientConfig creates SSH client configuration for NetBird peers
+func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
+ if !shouldGenerateSSHConfig(len(peers)) {
+ m.logSkipReason(len(peers))
+ return nil
+ }
+
+ sshConfig, err := m.buildSSHConfig(peers)
+ if err != nil {
+ return fmt.Errorf("build SSH config: %w", err)
+ }
+ return m.writeSSHConfig(sshConfig)
+}
+
+func (m *Manager) logSkipReason(peerCount int) {
+ if isSSHConfigDisabled() {
+ log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
+ } else {
+ log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.",
+ peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
+ }
+}
+
+func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
+ sshConfig := m.buildConfigHeader()
+
+ var allHostPatterns []string
+ for _, peer := range peers {
+ hostPatterns := m.buildHostPatterns(peer)
+ allHostPatterns = append(allHostPatterns, hostPatterns...)
+ }
+
+ if len(allHostPatterns) > 0 {
+ peerConfig, err := m.buildPeerConfig(allHostPatterns)
+ if err != nil {
+ return "", err
+ }
+ sshConfig += peerConfig
+ }
+
+ return sshConfig, nil
+}
+
+func (m *Manager) buildConfigHeader() string {
+ return "# NetBird SSH client configuration\n" +
+ "# Generated automatically - do not edit manually\n" +
+ "#\n" +
+ "# To disable SSH config management, use:\n" +
+ "# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" +
+ "#\n\n"
+}
+
+func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
+ uniquePatterns := make(map[string]bool)
+ var deduplicatedPatterns []string
+ for _, pattern := range allHostPatterns {
+ if !uniquePatterns[pattern] {
+ uniquePatterns[pattern] = true
+ deduplicatedPatterns = append(deduplicatedPatterns, pattern)
+ }
+ }
+
+ execPath, err := m.getNetBirdExecutablePath()
+ if err != nil {
+ return "", fmt.Errorf("get NetBird executable path: %w", err)
+ }
+
+ hostLine := strings.Join(deduplicatedPatterns, " ")
+ config := fmt.Sprintf("Host %s\n", hostLine)
+ config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
+ config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
+ config += " PasswordAuthentication yes\n"
+ config += " PubkeyAuthentication yes\n"
+ config += " BatchMode no\n"
+ config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
+ config += " StrictHostKeyChecking no\n"
+
+ if runtime.GOOS == "windows" {
+ config += " UserKnownHostsFile NUL\n"
+ } else {
+ config += " UserKnownHostsFile /dev/null\n"
+ }
+
+ config += " CheckHostIP no\n"
+ config += " LogLevel ERROR\n\n"
+
+ return config, nil
+}
+
+func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
+ var hostPatterns []string
+ if peer.IP != "" {
+ hostPatterns = append(hostPatterns, peer.IP)
+ }
+ if peer.FQDN != "" {
+ hostPatterns = append(hostPatterns, peer.FQDN)
+ }
+ if peer.Hostname != "" && peer.Hostname != peer.FQDN {
+ hostPatterns = append(hostPatterns, peer.Hostname)
+ }
+ return hostPatterns
+}
+
+func (m *Manager) writeSSHConfig(sshConfig string) error {
+ sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
+
+ if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
+ return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
+ }
+
+ if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
+ return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
+ }
+
+ log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
+ return nil
+}
+
+// RemoveSSHClientConfig removes NetBird SSH configuration
+func (m *Manager) RemoveSSHClientConfig() error {
+ sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
+ err := os.Remove(sshConfigPath)
+ if err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
+ }
+ if err == nil {
+ log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
+ }
+ return nil
+}
+
+func (m *Manager) getNetBirdExecutablePath() (string, error) {
+ execPath, err := os.Executable()
+ if err != nil {
+ return "", fmt.Errorf("retrieve executable path: %w", err)
+ }
+
+ realPath, err := filepath.EvalSymlinks(execPath)
+ if err != nil {
+ log.Debugf("symlink resolution failed: %v", err)
+ return execPath, nil
+ }
+
+ return realPath, nil
+}
+
+// GetSSHConfigDir returns the SSH config directory path
+func (m *Manager) GetSSHConfigDir() string {
+ return m.sshConfigDir
+}
+
+// GetSSHConfigFile returns the SSH config file name
+func (m *Manager) GetSSHConfigFile() string {
+ return m.sshConfigFile
+}
diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go
new file mode 100644
index 000000000..dc3ad95b3
--- /dev/null
+++ b/client/ssh/config/manager_test.go
@@ -0,0 +1,159 @@
+package config
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestManager_SetupSSHClientConfig(t *testing.T) {
+ // Create temporary directory for test
+ tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
+ require.NoError(t, err)
+ defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
+
+ // Override manager paths to use temp directory
+ manager := &Manager{
+ sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
+ sshConfigFile: "99-netbird.conf",
+ }
+
+ // Test SSH config generation with peers
+ peers := []PeerSSHInfo{
+ {
+ Hostname: "peer1",
+ IP: "100.125.1.1",
+ FQDN: "peer1.nb.internal",
+ },
+ {
+ Hostname: "peer2",
+ IP: "100.125.1.2",
+ FQDN: "peer2.nb.internal",
+ },
+ }
+
+ err = manager.SetupSSHClientConfig(peers)
+ require.NoError(t, err)
+
+ // Read generated config
+ configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
+ content, err := os.ReadFile(configPath)
+ require.NoError(t, err)
+
+ configStr := string(content)
+
+ // Verify the basic SSH config structure exists
+ assert.Contains(t, configStr, "# NetBird SSH client configuration")
+ assert.Contains(t, configStr, "Generated automatically - do not edit manually")
+
+ // Check that peer hostnames are included
+ assert.Contains(t, configStr, "100.125.1.1")
+ assert.Contains(t, configStr, "100.125.1.2")
+ assert.Contains(t, configStr, "peer1.nb.internal")
+ assert.Contains(t, configStr, "peer2.nb.internal")
+
+ // Check platform-specific UserKnownHostsFile
+ if runtime.GOOS == "windows" {
+ assert.Contains(t, configStr, "UserKnownHostsFile NUL")
+ } else {
+ assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
+ }
+}
+
+func TestGetSystemSSHConfigDir(t *testing.T) {
+ configDir := getSystemSSHConfigDir()
+
+ // Path should not be empty
+ assert.NotEmpty(t, configDir)
+
+ // Should be an absolute path
+ assert.True(t, filepath.IsAbs(configDir))
+
+ // On Unix systems, should start with /etc
+ // On Windows, should contain ProgramData
+ if runtime.GOOS == "windows" {
+ assert.Contains(t, strings.ToLower(configDir), "programdata")
+ } else {
+ assert.Contains(t, configDir, "/etc/ssh")
+ }
+}
+
+func TestManager_PeerLimit(t *testing.T) {
+ // Create temporary directory for test
+ tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
+ require.NoError(t, err)
+ defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
+
+ // Override manager paths to use temp directory
+ manager := &Manager{
+ sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
+ sshConfigFile: "99-netbird.conf",
+ }
+
+ // Generate many peers (more than limit)
+ var peers []PeerSSHInfo
+ for i := 0; i < MaxPeersForSSHConfig+10; i++ {
+ peers = append(peers, PeerSSHInfo{
+ Hostname: fmt.Sprintf("peer%d", i),
+ IP: fmt.Sprintf("100.125.1.%d", i%254+1),
+ FQDN: fmt.Sprintf("peer%d.nb.internal", i),
+ })
+ }
+
+ // Test that SSH config generation is skipped when too many peers
+ err = manager.SetupSSHClientConfig(peers)
+ require.NoError(t, err)
+
+ // Config should not be created due to peer limit
+ configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
+ _, err = os.Stat(configPath)
+ assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
+}
+
+func TestManager_ForcedSSHConfig(t *testing.T) {
+ // Set force environment variable
+ t.Setenv(EnvForceSSHConfig, "true")
+
+ // Create temporary directory for test
+ tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
+ require.NoError(t, err)
+ defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
+
+ // Override manager paths to use temp directory
+ manager := &Manager{
+ sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
+ sshConfigFile: "99-netbird.conf",
+ }
+
+ // Generate many peers (more than limit)
+ var peers []PeerSSHInfo
+ for i := 0; i < MaxPeersForSSHConfig+10; i++ {
+ peers = append(peers, PeerSSHInfo{
+ Hostname: fmt.Sprintf("peer%d", i),
+ IP: fmt.Sprintf("100.125.1.%d", i%254+1),
+ FQDN: fmt.Sprintf("peer%d.nb.internal", i),
+ })
+ }
+
+ // Test that SSH config generation is forced despite many peers
+ err = manager.SetupSSHClientConfig(peers)
+ require.NoError(t, err)
+
+ // Config should be created despite peer limit due to force flag
+ configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
+ _, err = os.Stat(configPath)
+ require.NoError(t, err, "SSH config should be created when forced")
+
+ // Verify config contains peer hostnames
+ content, err := os.ReadFile(configPath)
+ require.NoError(t, err)
+ configStr := string(content)
+ assert.Contains(t, configStr, "peer0.nb.internal")
+ assert.Contains(t, configStr, "peer1.nb.internal")
+}
diff --git a/client/ssh/config/shutdown_state.go b/client/ssh/config/shutdown_state.go
new file mode 100644
index 000000000..22f0e0678
--- /dev/null
+++ b/client/ssh/config/shutdown_state.go
@@ -0,0 +1,22 @@
+package config
+
+// ShutdownState represents SSH configuration state that needs to be cleaned up.
+type ShutdownState struct {
+ SSHConfigDir string
+ SSHConfigFile string
+}
+
+// Name returns the state name for the state manager.
+func (s *ShutdownState) Name() string {
+ return "ssh_config_state"
+}
+
+// Cleanup removes SSH client configuration files.
+func (s *ShutdownState) Cleanup() error {
+ manager := &Manager{
+ sshConfigDir: s.SSHConfigDir,
+ sshConfigFile: s.SSHConfigFile,
+ }
+
+ return manager.RemoveSSHClientConfig()
+}
diff --git a/client/ssh/detection/detection.go b/client/ssh/detection/detection.go
new file mode 100644
index 000000000..f23ea4c37
--- /dev/null
+++ b/client/ssh/detection/detection.go
@@ -0,0 +1,99 @@
+package detection
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ // ServerIdentifier is the base response for NetBird SSH servers
+ ServerIdentifier = "NetBird-SSH-Server"
+ // ProxyIdentifier is the base response for NetBird SSH proxy
+ ProxyIdentifier = "NetBird-SSH-Proxy"
+ // JWTRequiredMarker is appended to responses when JWT is required
+ JWTRequiredMarker = "NetBird-JWT-Required"
+
+ // DefaultTimeout is the default timeout for SSH server detection
+ DefaultTimeout = 5 * time.Second
+)
+
+type ServerType string
+
+const (
+ ServerTypeNetBirdJWT ServerType = "netbird-jwt"
+ ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
+ ServerTypeRegular ServerType = "regular"
+)
+
+// Dialer provides network connection capabilities
+type Dialer interface {
+ DialContext(ctx context.Context, network, address string) (net.Conn, error)
+}
+
+// RequiresJWT checks if the server type requires JWT authentication
+func (s ServerType) RequiresJWT() bool {
+ return s == ServerTypeNetBirdJWT
+}
+
+// ExitCode returns the exit code for the detect command
+func (s ServerType) ExitCode() int {
+ switch s {
+ case ServerTypeNetBirdJWT:
+ return 0
+ case ServerTypeNetBirdNoJWT:
+ return 1
+ case ServerTypeRegular:
+ return 2
+ default:
+ return 2
+ }
+}
+
+// DetectSSHServerType detects SSH server type using the provided dialer
+func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
+ targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
+
+ conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
+ if err != nil {
+ return ServerTypeRegular, fmt.Errorf("connect to %s: %w", targetAddr, err)
+ }
+ defer conn.Close()
+
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetReadDeadline(deadline); err != nil {
+ return ServerTypeRegular, fmt.Errorf("set read deadline: %w", err)
+ }
+ }
+
+ reader := bufio.NewReader(conn)
+ serverBanner, err := reader.ReadString('\n')
+ if err != nil {
+ return ServerTypeRegular, fmt.Errorf("read SSH banner: %w", err)
+ }
+
+ serverBanner = strings.TrimSpace(serverBanner)
+ log.Debugf("SSH server banner: %s", serverBanner)
+
+ if !strings.HasPrefix(serverBanner, "SSH-") {
+ log.Debugf("Invalid SSH banner")
+ return ServerTypeRegular, nil
+ }
+
+ if !strings.Contains(serverBanner, ServerIdentifier) {
+ log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
+ return ServerTypeRegular, nil
+ }
+
+ if strings.Contains(serverBanner, JWTRequiredMarker) {
+ return ServerTypeNetBirdJWT, nil
+ }
+
+ return ServerTypeNetBirdNoJWT, nil
+}
diff --git a/client/ssh/login.go b/client/ssh/login.go
deleted file mode 100644
index cb2615e55..000000000
--- a/client/ssh/login.go
+++ /dev/null
@@ -1,53 +0,0 @@
-//go:build !js
-
-package ssh
-
-import (
- "fmt"
- "net"
- "net/netip"
- "os"
- "os/exec"
- "runtime"
-
- "github.com/netbirdio/netbird/util"
-)
-
-func isRoot() bool {
- return os.Geteuid() == 0
-}
-
-func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) {
- if !isRoot() {
- shell := getUserShell(user)
- if shell == "" {
- shell = "/bin/sh"
- }
-
- return shell, []string{"-l"}, nil
- }
-
- loginPath, err = exec.LookPath("login")
- if err != nil {
- return "", nil, err
- }
-
- addrPort, err := netip.ParseAddrPort(remoteAddr.String())
- if err != nil {
- return "", nil, err
- }
-
- switch runtime.GOOS {
- case "linux":
- if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") {
- return loginPath, []string{"-f", user, "-p"}, nil
- }
- return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
- case "darwin":
- return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil
- case "freebsd":
- return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
- default:
- return "", nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS)
- }
-}
diff --git a/client/ssh/lookup.go b/client/ssh/lookup.go
deleted file mode 100644
index 9a7f6ff2e..000000000
--- a/client/ssh/lookup.go
+++ /dev/null
@@ -1,14 +0,0 @@
-//go:build !darwin
-// +build !darwin
-
-package ssh
-
-import "os/user"
-
-func userNameLookup(username string) (*user.User, error) {
- if username == "" || (username == "root" && !isRoot()) {
- return user.Current()
- }
-
- return user.Lookup(username)
-}
diff --git a/client/ssh/lookup_darwin.go b/client/ssh/lookup_darwin.go
deleted file mode 100644
index 913d049dc..000000000
--- a/client/ssh/lookup_darwin.go
+++ /dev/null
@@ -1,51 +0,0 @@
-//go:build darwin
-// +build darwin
-
-package ssh
-
-import (
- "bytes"
- "fmt"
- "os/exec"
- "os/user"
- "strings"
-)
-
-func userNameLookup(username string) (*user.User, error) {
- if username == "" || (username == "root" && !isRoot()) {
- return user.Current()
- }
-
- var userObject *user.User
- userObject, err := user.Lookup(username)
- if err != nil && err.Error() == user.UnknownUserError(username).Error() {
- return idUserNameLookup(username)
- } else if err != nil {
- return nil, err
- }
-
- return userObject, nil
-}
-
-func idUserNameLookup(username string) (*user.User, error) {
- cmd := exec.Command("id", "-P", username)
- out, err := cmd.CombinedOutput()
- if err != nil {
- return nil, fmt.Errorf("error while retrieving user with id -P command, error: %v", err)
- }
- colon := ":"
-
- if !bytes.Contains(out, []byte(username+colon)) {
- return nil, fmt.Errorf("unable to find user in returned string")
- }
- // netbird:********:501:20::0:0:netbird:/Users/netbird:/bin/zsh
- parts := strings.SplitN(string(out), colon, 10)
- userObject := &user.User{
- Username: parts[0],
- Uid: parts[2],
- Gid: parts[3],
- Name: parts[7],
- HomeDir: parts[8],
- }
- return userObject, nil
-}
diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go
new file mode 100644
index 000000000..4e807e33c
--- /dev/null
+++ b/client/ssh/proxy/proxy.go
@@ -0,0 +1,394 @@
+package proxy
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+ cryptossh "golang.org/x/crypto/ssh"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+
+ "github.com/netbirdio/netbird/client/internal/profilemanager"
+ "github.com/netbirdio/netbird/client/proto"
+ nbssh "github.com/netbirdio/netbird/client/ssh"
+ "github.com/netbirdio/netbird/client/ssh/detection"
+ "github.com/netbirdio/netbird/version"
+)
+
+const (
+ // sshConnectionTimeout is the timeout for SSH TCP connection establishment
+ sshConnectionTimeout = 120 * time.Second
+ // sshHandshakeTimeout is the timeout for SSH handshake completion
+ sshHandshakeTimeout = 30 * time.Second
+
+ jwtAuthErrorMsg = "JWT authentication: %w"
+)
+
+type SSHProxy struct {
+ daemonAddr string
+ targetHost string
+ targetPort int
+ stderr io.Writer
+ conn *grpc.ClientConn
+ daemonClient proto.DaemonServiceClient
+ browserOpener func(string) error
+}
+
+func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
+ grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
+ grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
+ if err != nil {
+ return nil, fmt.Errorf("connect to daemon: %w", err)
+ }
+
+ return &SSHProxy{
+ daemonAddr: daemonAddr,
+ targetHost: targetHost,
+ targetPort: targetPort,
+ stderr: stderr,
+ conn: grpcConn,
+ daemonClient: proto.NewDaemonServiceClient(grpcConn),
+ browserOpener: browserOpener,
+ }, nil
+}
+
+func (p *SSHProxy) Close() error {
+ if p.conn != nil {
+ return p.conn.Close()
+ }
+ return nil
+}
+
+func (p *SSHProxy) Connect(ctx context.Context) error {
+ hint := profilemanager.GetLoginHint()
+
+ jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint, p.browserOpener)
+ if err != nil {
+ return fmt.Errorf(jwtAuthErrorMsg, err)
+ }
+
+ return p.runProxySSHServer(ctx, jwtToken)
+}
+
+func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
+ serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
+
+ sshServer := &ssh.Server{
+ Handler: func(s ssh.Session) {
+ p.handleSSHSession(ctx, s, jwtToken)
+ },
+ ChannelHandlers: map[string]ssh.ChannelHandler{
+ "session": ssh.DefaultSessionHandler,
+ "direct-tcpip": p.directTCPIPHandler,
+ },
+ SubsystemHandlers: map[string]ssh.SubsystemHandler{
+ "sftp": func(s ssh.Session) {
+ p.sftpSubsystemHandler(s, jwtToken)
+ },
+ },
+ RequestHandlers: map[string]ssh.RequestHandler{
+ "tcpip-forward": p.tcpipForwardHandler,
+ "cancel-tcpip-forward": p.cancelTcpipForwardHandler,
+ },
+ Version: serverVersion,
+ }
+
+ hostKey, err := generateHostKey()
+ if err != nil {
+ return fmt.Errorf("generate host key: %w", err)
+ }
+ sshServer.HostSigners = []ssh.Signer{hostKey}
+
+ conn := &stdioConn{
+ stdin: os.Stdin,
+ stdout: os.Stdout,
+ }
+
+ sshServer.HandleConn(conn)
+
+ return nil
+}
+
+func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
+ targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
+
+ sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
+ if err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
+ return
+ }
+ defer func() { _ = sshClient.Close() }()
+
+ serverSession, err := sshClient.NewSession()
+ if err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
+ return
+ }
+ defer func() { _ = serverSession.Close() }()
+
+ serverSession.Stdin = session
+ serverSession.Stdout = session
+ serverSession.Stderr = session.Stderr()
+
+ ptyReq, winCh, isPty := session.Pty()
+ if isPty {
+ if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
+ log.Debugf("PTY request to backend: %v", err)
+ }
+
+ go func() {
+ for win := range winCh {
+ if err := serverSession.WindowChange(win.Height, win.Width); err != nil {
+ log.Debugf("window change: %v", err)
+ }
+ }
+ }()
+ }
+
+ if len(session.Command()) > 0 {
+ if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
+ log.Debugf("run command: %v", err)
+ p.handleProxyExitCode(session, err)
+ }
+ return
+ }
+
+ if err = serverSession.Shell(); err != nil {
+ log.Debugf("start shell: %v", err)
+ return
+ }
+ if err := serverSession.Wait(); err != nil {
+ log.Debugf("session wait: %v", err)
+ p.handleProxyExitCode(session, err)
+ }
+}
+
+func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
+ var exitErr *cryptossh.ExitError
+ if errors.As(err, &exitErr) {
+ if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
+ log.Debugf("set exit status: %v", exitErr)
+ }
+ }
+}
+
+func generateHostKey() (ssh.Signer, error) {
+ keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ if err != nil {
+ return nil, fmt.Errorf("generate ED25519 key: %w", err)
+ }
+
+ signer, err := cryptossh.ParsePrivateKey(keyPEM)
+ if err != nil {
+ return nil, fmt.Errorf("parse private key: %w", err)
+ }
+
+ return signer, nil
+}
+
+type stdioConn struct {
+ stdin io.Reader
+ stdout io.Writer
+ closed bool
+ mu sync.Mutex
+}
+
+func (c *stdioConn) Read(b []byte) (n int, err error) {
+ c.mu.Lock()
+ if c.closed {
+ c.mu.Unlock()
+ return 0, io.EOF
+ }
+ c.mu.Unlock()
+ return c.stdin.Read(b)
+}
+
+func (c *stdioConn) Write(b []byte) (n int, err error) {
+ c.mu.Lock()
+ if c.closed {
+ c.mu.Unlock()
+ return 0, io.ErrClosedPipe
+ }
+ c.mu.Unlock()
+ return c.stdout.Write(b)
+}
+
+func (c *stdioConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.closed = true
+ return nil
+}
+
+func (c *stdioConn) LocalAddr() net.Addr {
+ return &net.UnixAddr{Name: "stdio", Net: "unix"}
+}
+
+func (c *stdioConn) RemoteAddr() net.Addr {
+ return &net.UnixAddr{Name: "stdio", Net: "unix"}
+}
+
+func (c *stdioConn) SetDeadline(_ time.Time) error {
+ return nil
+}
+
+func (c *stdioConn) SetReadDeadline(_ time.Time) error {
+ return nil
+}
+
+func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
+ return nil
+}
+
+func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
+ _ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
+}
+
+func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
+ ctx, cancel := context.WithCancel(s.Context())
+ defer cancel()
+
+ targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
+
+ sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
+ if err != nil {
+ _, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
+ _ = s.Exit(1)
+ return
+ }
+ defer func() {
+ if err := sshClient.Close(); err != nil {
+ log.Debugf("close SSH client: %v", err)
+ }
+ }()
+
+ serverSession, err := sshClient.NewSession()
+ if err != nil {
+ _, _ = fmt.Fprintf(s, "create server session: %v\n", err)
+ _ = s.Exit(1)
+ return
+ }
+ defer func() {
+ if err := serverSession.Close(); err != nil {
+ log.Debugf("close server session: %v", err)
+ }
+ }()
+
+ stdin, stdout, err := p.setupSFTPPipes(serverSession)
+ if err != nil {
+ log.Debugf("setup SFTP pipes: %v", err)
+ _ = s.Exit(1)
+ return
+ }
+
+ if err := serverSession.RequestSubsystem("sftp"); err != nil {
+ _, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
+ _ = s.Exit(1)
+ return
+ }
+
+ p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
+}
+
+func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
+ stdin, err := serverSession.StdinPipe()
+ if err != nil {
+ return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
+ }
+
+ stdout, err := serverSession.StdoutPipe()
+ if err != nil {
+ return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
+ }
+
+ return stdin, stdout, nil
+}
+
+func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
+ copyErrCh := make(chan error, 2)
+
+ go func() {
+ _, err := io.Copy(stdin, s)
+ if err != nil {
+ log.Debugf("SFTP client to server copy: %v", err)
+ }
+ if err := stdin.Close(); err != nil {
+ log.Debugf("close stdin: %v", err)
+ }
+ copyErrCh <- err
+ }()
+
+ go func() {
+ _, err := io.Copy(s, stdout)
+ if err != nil {
+ log.Debugf("SFTP server to client copy: %v", err)
+ }
+ copyErrCh <- err
+ }()
+
+ go func() {
+ <-ctx.Done()
+ if err := serverSession.Close(); err != nil {
+ log.Debugf("force close server session on context cancellation: %v", err)
+ }
+ }()
+
+ for i := 0; i < 2; i++ {
+ if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
+ log.Debugf("SFTP copy error: %v", err)
+ }
+ }
+
+ if err := serverSession.Wait(); err != nil {
+ log.Debugf("SFTP session ended: %v", err)
+ }
+}
+
+func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
+ return false, []byte("port forwarding not supported in proxy")
+}
+
+func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
+ return true, nil
+}
+
+func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
+ config := &cryptossh.ClientConfig{
+ User: user,
+ Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
+ Timeout: sshHandshakeTimeout,
+ HostKeyCallback: p.verifyHostKey,
+ }
+
+ dialer := &net.Dialer{
+ Timeout: sshConnectionTimeout,
+ }
+ conn, err := dialer.DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return nil, fmt.Errorf("connect to server: %w", err)
+ }
+
+ clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
+ if err != nil {
+ _ = conn.Close()
+ return nil, fmt.Errorf("SSH handshake: %w", err)
+ }
+
+ return cryptossh.NewClient(clientConn, chans, reqs), nil
+}
+
+func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
+ verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
+ callback := nbssh.CreateHostKeyCallback(verifier)
+ return callback(hostname, remote, key)
+}
diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go
new file mode 100644
index 000000000..81d588801
--- /dev/null
+++ b/client/ssh/proxy/proxy_test.go
@@ -0,0 +1,384 @@
+package proxy
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "math/big"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "runtime"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ cryptossh "golang.org/x/crypto/ssh"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+
+ "github.com/netbirdio/netbird/client/proto"
+ nbssh "github.com/netbirdio/netbird/client/ssh"
+ sshauth "github.com/netbirdio/netbird/client/ssh/auth"
+ "github.com/netbirdio/netbird/client/ssh/server"
+ "github.com/netbirdio/netbird/client/ssh/testutil"
+ nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
+ sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
+)
+
+func TestMain(m *testing.M) {
+ if len(os.Args) > 2 && os.Args[1] == "ssh" {
+ if os.Args[2] == "exec" {
+ if len(os.Args) > 3 {
+ cmd := os.Args[3]
+ if cmd == "echo" && len(os.Args) > 4 {
+ fmt.Fprintln(os.Stdout, os.Args[4])
+ os.Exit(0)
+ }
+ }
+ fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
+ os.Exit(1)
+ }
+ }
+
+ code := m.Run()
+
+ testutil.CleanupTestUsers()
+
+ os.Exit(code)
+}
+
+func TestSSHProxy_verifyHostKey(t *testing.T) {
+ t.Run("calls daemon to verify host key", func(t *testing.T) {
+ mockDaemon := startMockDaemon(t)
+ defer mockDaemon.stop()
+
+ grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
+ require.NoError(t, err)
+ defer func() { _ = grpcConn.Close() }()
+
+ proxy := &SSHProxy{
+ daemonAddr: mockDaemon.addr,
+ daemonClient: proto.NewDaemonServiceClient(grpcConn),
+ }
+
+ testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+ testPubKey, err := nbssh.GeneratePublicKey(testKey)
+ require.NoError(t, err)
+
+ mockDaemon.setHostKey("test-host", testPubKey)
+
+ err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
+ assert.NoError(t, err)
+ })
+
+ t.Run("rejects unknown host key", func(t *testing.T) {
+ mockDaemon := startMockDaemon(t)
+ defer mockDaemon.stop()
+
+ grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
+ require.NoError(t, err)
+ defer func() { _ = grpcConn.Close() }()
+
+ proxy := &SSHProxy{
+ daemonAddr: mockDaemon.addr,
+ daemonClient: proto.NewDaemonServiceClient(grpcConn),
+ }
+
+ unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+ unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
+ require.NoError(t, err)
+
+ err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "peer unknown-host not found in network")
+ })
+}
+
+func TestSSHProxy_Connect(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping integration test in short mode")
+ }
+
+ // TODO: Windows test times out - user switching and command execution tested on Linux
+ if runtime.GOOS == "windows" {
+ t.Skip("Skipping on Windows - covered by Linux tests")
+ }
+
+ const (
+ issuer = "https://test-issuer.example.com"
+ audience = "test-audience"
+ )
+
+ jwksServer, privateKey, jwksURL := setupJWKSServer(t)
+ defer jwksServer.Close()
+
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+ hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
+ require.NoError(t, err)
+
+ serverConfig := &server.Config{
+ HostKeyPEM: hostKey,
+ JWT: &server.JWTConfig{
+ Issuer: issuer,
+ Audience: audience,
+ KeysLocation: jwksURL,
+ },
+ }
+ sshServer := server.New(serverConfig)
+ sshServer.SetAllowRootLogin(true)
+
+ // Configure SSH authorization for the test user
+ testUsername := testutil.GetTestUsername(t)
+ testJWTUser := "test-username"
+ testUserHash, err := sshuserhash.HashUserID(testJWTUser)
+ require.NoError(t, err)
+
+ authConfig := &sshauth.Config{
+ UserIDClaim: sshauth.DefaultUserIDClaim,
+ AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
+ MachineUsers: map[string][]uint32{
+ testUsername: {0}, // Index 0 in AuthorizedUsers
+ },
+ }
+ sshServer.UpdateSSHAuth(authConfig)
+
+ sshServerAddr := server.StartTestServer(t, sshServer)
+ defer func() { _ = sshServer.Stop() }()
+
+ mockDaemon := startMockDaemon(t)
+ defer mockDaemon.stop()
+
+ host, portStr, err := net.SplitHostPort(sshServerAddr)
+ require.NoError(t, err)
+ port, err := strconv.Atoi(portStr)
+ require.NoError(t, err)
+
+ mockDaemon.setHostKey(host, hostPubKey)
+
+ validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
+ mockDaemon.setJWTToken(validToken)
+
+ proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
+ require.NoError(t, err)
+
+ clientConn, proxyConn := net.Pipe()
+ defer func() { _ = clientConn.Close() }()
+
+ origStdin := os.Stdin
+ origStdout := os.Stdout
+ defer func() {
+ os.Stdin = origStdin
+ os.Stdout = origStdout
+ }()
+
+ stdinReader, stdinWriter, err := os.Pipe()
+ require.NoError(t, err)
+ stdoutReader, stdoutWriter, err := os.Pipe()
+ require.NoError(t, err)
+
+ os.Stdin = stdinReader
+ os.Stdout = stdoutWriter
+
+ go func() {
+ _, _ = io.Copy(stdinWriter, proxyConn)
+ }()
+ go func() {
+ _, _ = io.Copy(proxyConn, stdoutReader)
+ }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ connectErrCh := make(chan error, 1)
+ go func() {
+ connectErrCh <- proxyInstance.Connect(ctx)
+ }()
+
+ sshConfig := &cryptossh.ClientConfig{
+ User: testutil.GetTestUsername(t),
+ Auth: []cryptossh.AuthMethod{},
+ HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
+ Timeout: 3 * time.Second,
+ }
+
+ sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
+ require.NoError(t, err, "Should connect to proxy server")
+ defer func() { _ = sshClientConn.Close() }()
+
+ sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
+
+ session, err := sshClient.NewSession()
+ require.NoError(t, err, "Should create session through full proxy to backend")
+
+ outputCh := make(chan []byte, 1)
+ errCh := make(chan error, 1)
+ go func() {
+ output, err := session.Output("echo hello-from-proxy")
+ outputCh <- output
+ errCh <- err
+ }()
+
+ select {
+ case output := <-outputCh:
+ err := <-errCh
+ require.NoError(t, err, "Command should execute successfully through proxy")
+ assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
+ case <-time.After(3 * time.Second):
+ t.Fatal("Command execution timed out")
+ }
+
+ _ = session.Close()
+ _ = sshClient.Close()
+ _ = clientConn.Close()
+ cancel()
+}
+
+type mockDaemonServer struct {
+ proto.UnimplementedDaemonServiceServer
+ hostKeys map[string][]byte
+ jwtToken string
+}
+
+func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
+ key, found := m.hostKeys[req.PeerAddress]
+ return &proto.GetPeerSSHHostKeyResponse{
+ Found: found,
+ SshHostKey: key,
+ }, nil
+}
+
+func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
+ return &proto.RequestJWTAuthResponse{
+ CachedToken: m.jwtToken,
+ }, nil
+}
+
+func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
+ return &proto.WaitJWTTokenResponse{
+ Token: m.jwtToken,
+ }, nil
+}
+
+type mockDaemon struct {
+ addr string
+ server *grpc.Server
+ impl *mockDaemonServer
+}
+
+func startMockDaemon(t *testing.T) *mockDaemon {
+ t.Helper()
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+
+ impl := &mockDaemonServer{
+ hostKeys: make(map[string][]byte),
+ jwtToken: "test-jwt-token",
+ }
+
+ grpcServer := grpc.NewServer()
+ proto.RegisterDaemonServiceServer(grpcServer, impl)
+
+ go func() {
+ _ = grpcServer.Serve(listener)
+ }()
+
+ return &mockDaemon{
+ addr: listener.Addr().String(),
+ server: grpcServer,
+ impl: impl,
+ }
+}
+
+func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
+ m.impl.hostKeys[addr] = pubKey
+}
+
+func (m *mockDaemon) setJWTToken(token string) {
+ m.impl.jwtToken = token
+}
+
+func (m *mockDaemon) stop() {
+ if m.server != nil {
+ m.server.Stop()
+ }
+}
+
+func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
+ t.Helper()
+ pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
+ require.NoError(t, err)
+ return pubKey
+}
+
+func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
+ t.Helper()
+ privateKey, jwksJSON := generateTestJWKS(t)
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ if _, err := w.Write(jwksJSON); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+ }))
+
+ return server, privateKey, server.URL
+}
+
+func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
+ t.Helper()
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ publicKey := &privateKey.PublicKey
+ n := publicKey.N.Bytes()
+ e := publicKey.E
+
+ jwk := nbjwt.JSONWebKey{
+ Kty: "RSA",
+ Kid: "test-key-id",
+ Use: "sig",
+ N: base64.RawURLEncoding.EncodeToString(n),
+ E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
+ }
+
+ jwks := nbjwt.Jwks{
+ Keys: []nbjwt.JSONWebKey{jwk},
+ }
+
+ jwksJSON, err := json.Marshal(jwks)
+ require.NoError(t, err)
+
+ return privateKey, jwksJSON
+}
+
+func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
+ t.Helper()
+ claims := jwt.MapClaims{
+ "iss": issuer,
+ "aud": audience,
+ "sub": user,
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ token.Header["kid"] = "test-key-id"
+
+ tokenString, err := token.SignedString(privateKey)
+ require.NoError(t, err)
+
+ return tokenString
+}
diff --git a/client/ssh/server.go b/client/ssh/server.go
deleted file mode 100644
index 8c5db2547..000000000
--- a/client/ssh/server.go
+++ /dev/null
@@ -1,280 +0,0 @@
-//go:build !js
-
-package ssh
-
-import (
- "fmt"
- "io"
- "net"
- "os"
- "os/exec"
- "os/user"
- "runtime"
- "strings"
- "sync"
- "time"
-
- "github.com/creack/pty"
- "github.com/gliderlabs/ssh"
- log "github.com/sirupsen/logrus"
-)
-
-// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
-const DefaultSSHPort = 44338
-
-// TerminalTimeout is the timeout for terminal session to be ready
-const TerminalTimeout = 10 * time.Second
-
-// TerminalBackoffDelay is the delay between terminal session readiness checks
-const TerminalBackoffDelay = 500 * time.Millisecond
-
-// DefaultSSHServer is a function that creates DefaultServer
-func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
- return newDefaultServer(hostKeyPEM, addr)
-}
-
-// Server is an interface of SSH server
-type Server interface {
- // Stop stops SSH server.
- Stop() error
- // Start starts SSH server. Blocking
- Start() error
- // RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
- RemoveAuthorizedKey(peer string)
- // AddAuthorizedKey add a given peer key to server authorized keys
- AddAuthorizedKey(peer, newKey string) error
-}
-
-// DefaultServer is the embedded NetBird SSH server
-type DefaultServer struct {
- listener net.Listener
- // authorizedKeys is ssh pub key indexed by peer WireGuard public key
- authorizedKeys map[string]ssh.PublicKey
- mu sync.Mutex
- hostKeyPEM []byte
- sessions []ssh.Session
-}
-
-// newDefaultServer creates new server with provided host key
-func newDefaultServer(hostKeyPEM []byte, addr string) (*DefaultServer, error) {
- ln, err := net.Listen("tcp", addr)
- if err != nil {
- return nil, err
- }
- allowedKeys := make(map[string]ssh.PublicKey)
- return &DefaultServer{listener: ln, mu: sync.Mutex{}, hostKeyPEM: hostKeyPEM, authorizedKeys: allowedKeys, sessions: make([]ssh.Session, 0)}, nil
-}
-
-// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
-func (srv *DefaultServer) RemoveAuthorizedKey(peer string) {
- srv.mu.Lock()
- defer srv.mu.Unlock()
-
- delete(srv.authorizedKeys, peer)
-}
-
-// AddAuthorizedKey add a given peer key to server authorized keys
-func (srv *DefaultServer) AddAuthorizedKey(peer, newKey string) error {
- srv.mu.Lock()
- defer srv.mu.Unlock()
-
- parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
- if err != nil {
- return err
- }
-
- srv.authorizedKeys[peer] = parsedKey
- return nil
-}
-
-// Stop stops SSH server.
-func (srv *DefaultServer) Stop() error {
- srv.mu.Lock()
- defer srv.mu.Unlock()
- err := srv.listener.Close()
- if err != nil {
- return err
- }
- for _, session := range srv.sessions {
- err := session.Close()
- if err != nil {
- log.Warnf("failed closing SSH session from %v", err)
- }
- }
-
- return nil
-}
-
-func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
- srv.mu.Lock()
- defer srv.mu.Unlock()
-
- for _, allowed := range srv.authorizedKeys {
- if ssh.KeysEqual(allowed, key) {
- return true
- }
- }
-
- return false
-}
-
-func prepareUserEnv(user *user.User, shell string) []string {
- return []string{
- fmt.Sprint("SHELL=" + shell),
- fmt.Sprint("USER=" + user.Username),
- fmt.Sprint("HOME=" + user.HomeDir),
- }
-}
-
-func acceptEnv(s string) bool {
- split := strings.Split(s, "=")
- if len(split) != 2 {
- return false
- }
- return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_")
-}
-
-// sessionHandler handles SSH session post auth
-func (srv *DefaultServer) sessionHandler(session ssh.Session) {
- srv.mu.Lock()
- srv.sessions = append(srv.sessions, session)
- srv.mu.Unlock()
-
- defer func() {
- err := session.Close()
- if err != nil {
- return
- }
- }()
-
- log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
-
- localUser, err := userNameLookup(session.User())
- if err != nil {
- _, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
- err = session.Exit(1)
- if err != nil {
- return
- }
- log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User())
- return
- }
-
- ptyReq, winCh, isPty := session.Pty()
- if isPty {
- loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr())
- if err != nil {
- log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String())
- return
- }
- cmd := exec.Command(loginCmd, loginArgs...)
- go func() {
- <-session.Context().Done()
- if cmd.Process == nil {
- return
- }
- err := cmd.Process.Kill()
- if err != nil {
- log.Debugf("failed killing SSH process %v", err)
- return
- }
- }()
- cmd.Dir = localUser.HomeDir
- cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
- cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...)
- for _, v := range session.Environ() {
- if acceptEnv(v) {
- cmd.Env = append(cmd.Env, v)
- }
- }
-
- log.Debugf("Login command: %s", cmd.String())
- file, err := pty.Start(cmd)
- if err != nil {
- log.Errorf("failed starting SSH server: %v", err)
- }
-
- go func() {
- for win := range winCh {
- setWinSize(file, win.Width, win.Height)
- }
- }()
-
- srv.stdInOut(file, session)
-
- err = cmd.Wait()
- if err != nil {
- return
- }
- } else {
- _, err := io.WriteString(session, "only PTY is supported.\n")
- if err != nil {
- return
- }
- err = session.Exit(1)
- if err != nil {
- return
- }
- }
- log.Debugf("SSH session ended")
-}
-
-func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
- go func() {
- // stdin
- _, err := io.Copy(file, session)
- if err != nil {
- _ = session.Exit(1)
- return
- }
- }()
-
- // AWS Linux 2 machines need some time to open the terminal so we need to wait for it
- timer := time.NewTimer(TerminalTimeout)
- for {
- select {
- case <-timer.C:
- _, _ = session.Write([]byte("Reached timeout while opening connection\n"))
- _ = session.Exit(1)
- return
- default:
- // stdout
- writtenBytes, err := io.Copy(session, file)
- if err != nil && writtenBytes != 0 {
- _ = session.Exit(0)
- return
- }
- time.Sleep(TerminalBackoffDelay)
- }
- }
-}
-
-// Start starts SSH server. Blocking
-func (srv *DefaultServer) Start() error {
- log.Infof("starting SSH server on addr: %s", srv.listener.Addr().String())
-
- publicKeyOption := ssh.PublicKeyAuth(srv.publicKeyHandler)
- hostKeyPEM := ssh.HostKeyPEM(srv.hostKeyPEM)
- err := ssh.Serve(srv.listener, srv.sessionHandler, publicKeyOption, hostKeyPEM)
- if err != nil {
- return err
- }
-
- return nil
-}
-
-func getUserShell(userID string) string {
- if runtime.GOOS == "linux" {
- output, _ := exec.Command("getent", "passwd", userID).Output()
- line := strings.SplitN(string(output), ":", 10)
- if len(line) > 6 {
- return strings.TrimSpace(line[6])
- }
- }
-
- shell := os.Getenv("SHELL")
- if shell == "" {
- shell = "/bin/sh"
- }
- return shell
-}
diff --git a/client/ssh/server/command_execution.go b/client/ssh/server/command_execution.go
new file mode 100644
index 000000000..7a01ce4f6
--- /dev/null
+++ b/client/ssh/server/command_execution.go
@@ -0,0 +1,206 @@
+package server
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "time"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+)
+
+// handleCommand executes an SSH command with privilege validation
+func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
+ hasPty := winCh != nil
+
+ commandType := "command"
+ if hasPty {
+ commandType = "Pty command"
+ }
+
+ logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
+
+ execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
+ if err != nil {
+ logger.Errorf("%s creation failed: %v", commandType, err)
+
+ errorMsg := fmt.Sprintf("Cannot create %s - platform may not support user switching", commandType)
+ if hasPty {
+ errorMsg += " with Pty"
+ }
+ errorMsg += "\n"
+
+ if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
+ logger.Debugf(errWriteSession, writeErr)
+ }
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return
+ }
+
+ if !hasPty {
+ if s.executeCommand(logger, session, execCmd, cleanup) {
+ logger.Debugf("%s execution completed", commandType)
+ }
+ return
+ }
+
+ defer cleanup()
+
+ ptyReq, _, _ := session.Pty()
+ if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
+ logger.Debugf("%s execution completed", commandType)
+ }
+}
+
+func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
+ localUser := privilegeResult.User
+ if localUser == nil {
+ return nil, nil, errors.New("no user in privilege result")
+ }
+
+ // If PTY requested but su doesn't support --pty, skip su and use executor
+ // This ensures PTY functionality is provided (executor runs within our allocated PTY)
+ if hasPty && !s.suSupportsPty {
+ log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
+ cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
+ if err != nil {
+ return nil, nil, fmt.Errorf("create command with privileges: %w", err)
+ }
+ cmd.Env = s.prepareCommandEnv(localUser, session)
+ return cmd, cleanup, nil
+ }
+
+ // Try su first for system integration (PAM/audit) when privileged
+ cmd, err := s.createSuCommand(session, localUser, hasPty)
+ if err != nil || privilegeResult.UsedFallback {
+ log.Debugf("su command failed, falling back to executor: %v", err)
+ cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
+ if err != nil {
+ return nil, nil, fmt.Errorf("create command with privileges: %w", err)
+ }
+ cmd.Env = s.prepareCommandEnv(localUser, session)
+ return cmd, cleanup, nil
+ }
+
+ cmd.Env = s.prepareCommandEnv(localUser, session)
+ return cmd, func() {}, nil
+}
+
+// executeCommand executes the command and handles I/O and exit codes
+func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool {
+ defer cleanup()
+
+ s.setupProcessGroup(execCmd)
+
+ stdinPipe, err := execCmd.StdinPipe()
+ if err != nil {
+ logger.Errorf("create stdin pipe: %v", err)
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return false
+ }
+
+ execCmd.Stdout = session
+ execCmd.Stderr = session.Stderr()
+
+ if execCmd.Dir != "" {
+ if _, err := os.Stat(execCmd.Dir); err != nil {
+ logger.Warnf("working directory does not exist: %s (%v)", execCmd.Dir, err)
+ execCmd.Dir = "/"
+ }
+ }
+
+ if err := execCmd.Start(); err != nil {
+ logger.Errorf("command start failed: %v", err)
+ // no user message for exec failure, just exit
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return false
+ }
+
+ go s.handleCommandIO(logger, stdinPipe, session)
+ return s.waitForCommandCleanup(logger, session, execCmd)
+}
+
+// handleCommandIO manages stdin/stdout copying in a goroutine
+func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, session ssh.Session) {
+ defer func() {
+ if err := stdinPipe.Close(); err != nil {
+ logger.Debugf("stdin pipe close error: %v", err)
+ }
+ }()
+ if _, err := io.Copy(stdinPipe, session); err != nil {
+ logger.Debugf("stdin copy error: %v", err)
+ }
+}
+
+// waitForCommandCleanup waits for command completion with session disconnect handling
+func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
+ ctx := session.Context()
+ done := make(chan error, 1)
+ go func() {
+ done <- execCmd.Wait()
+ }()
+
+ select {
+ case <-ctx.Done():
+ logger.Debugf("session cancelled, terminating command")
+ s.killProcessGroup(execCmd)
+
+ select {
+ case err := <-done:
+ logger.Tracef("command terminated after session cancellation: %v", err)
+ case <-time.After(5 * time.Second):
+ logger.Warnf("command did not terminate within 5 seconds after session cancellation")
+ }
+
+ if err := session.Exit(130); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return false
+
+ case err := <-done:
+ return s.handleCommandCompletion(logger, session, err)
+ }
+}
+
+// handleCommandCompletion handles command completion
+func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool {
+ if err != nil {
+ logger.Debugf("command execution failed: %v", err)
+ s.handleSessionExit(session, err, logger)
+ return false
+ }
+
+ s.handleSessionExit(session, nil, logger)
+ return true
+}
+
+// handleSessionExit handles command errors and sets appropriate exit codes
+func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) {
+ if err == nil {
+ if err := session.Exit(0); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return
+ }
+
+ var exitError *exec.ExitError
+ if errors.As(err, &exitError) {
+ if err := session.Exit(exitError.ExitCode()); err != nil {
+ logSessionExitError(logger, err)
+ }
+ } else {
+ logger.Debugf("non-exit error in command execution: %v", err)
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ }
+}
diff --git a/client/ssh/server/command_execution_js.go b/client/ssh/server/command_execution_js.go
new file mode 100644
index 000000000..01759a337
--- /dev/null
+++ b/client/ssh/server/command_execution_js.go
@@ -0,0 +1,57 @@
+//go:build js
+
+package server
+
+import (
+ "context"
+ "errors"
+ "os/exec"
+ "os/user"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+)
+
+var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
+
+// createSuCommand is not supported on JS/WASM
+func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
+ return nil, errNotSupported
+}
+
+// createExecutorCommand is not supported on JS/WASM
+func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
+ return nil, nil, errNotSupported
+}
+
+// prepareCommandEnv is not supported on JS/WASM
+func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
+ return nil
+}
+
+// setupProcessGroup is not supported on JS/WASM
+func (s *Server) setupProcessGroup(_ *exec.Cmd) {
+}
+
+// killProcessGroup is not supported on JS/WASM
+func (s *Server) killProcessGroup(*exec.Cmd) {
+}
+
+// detectSuPtySupport always returns false on JS/WASM
+func (s *Server) detectSuPtySupport(context.Context) bool {
+ return false
+}
+
+// detectUtilLinuxLogin always returns false on JS/WASM
+func (s *Server) detectUtilLinuxLogin(context.Context) bool {
+ return false
+}
+
+// executeCommandWithPty is not supported on JS/WASM
+func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
+ logger.Errorf("PTY command execution not supported on JS/WASM")
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return false
+}
diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go
new file mode 100644
index 000000000..db1a9bcfe
--- /dev/null
+++ b/client/ssh/server/command_execution_unix.go
@@ -0,0 +1,353 @@
+//go:build unix
+
+package server
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "os/user"
+ "runtime"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/creack/pty"
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+)
+
+// ptyManager manages Pty file operations with thread safety
+type ptyManager struct {
+ file *os.File
+ mu sync.RWMutex
+ closed bool
+ closeErr error
+ once sync.Once
+}
+
+func newPtyManager(file *os.File) *ptyManager {
+ return &ptyManager{file: file}
+}
+
+func (pm *ptyManager) Close() error {
+ pm.once.Do(func() {
+ pm.mu.Lock()
+ pm.closed = true
+ pm.closeErr = pm.file.Close()
+ pm.mu.Unlock()
+ })
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
+ return pm.closeErr
+}
+
+func (pm *ptyManager) Setsize(ws *pty.Winsize) error {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
+ if pm.closed {
+ return errors.New("pty is closed")
+ }
+ return pty.Setsize(pm.file, ws)
+}
+
+func (pm *ptyManager) File() *os.File {
+ return pm.file
+}
+
+// detectSuPtySupport checks if su supports the --pty flag
+func (s *Server) detectSuPtySupport(ctx context.Context) bool {
+ ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
+ defer cancel()
+
+ cmd := exec.CommandContext(ctx, "su", "--help")
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ log.Debugf("su --help failed (may not support --help): %v", err)
+ return false
+ }
+
+ supported := strings.Contains(string(output), "--pty")
+ log.Debugf("su --pty support detected: %v", supported)
+ return supported
+}
+
+// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils).
+// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent.
+// See https://bugs.debian.org/1078023 for details.
+func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
+ if runtime.GOOS != "linux" {
+ return false
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
+ defer cancel()
+
+ cmd := exec.CommandContext(ctx, "login", "--version")
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ log.Debugf("login --version failed (likely shadow-utils): %v", err)
+ return false
+ }
+
+ isUtilLinux := strings.Contains(string(output), "util-linux")
+ log.Debugf("util-linux login detected: %v", isUtilLinux)
+ return isUtilLinux
+}
+
+// createSuCommand creates a command using su -l -c for privilege switching
+func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
+ suPath, err := exec.LookPath("su")
+ if err != nil {
+ return nil, fmt.Errorf("su command not available: %w", err)
+ }
+
+ command := session.RawCommand()
+ if command == "" {
+ return nil, fmt.Errorf("no command specified for su execution")
+ }
+
+ args := []string{"-l"}
+ if hasPty && s.suSupportsPty {
+ args = append(args, "--pty")
+ }
+ args = append(args, localUser.Username, "-c", command)
+
+ cmd := exec.CommandContext(session.Context(), suPath, args...)
+ cmd.Dir = localUser.HomeDir
+
+ return cmd, nil
+}
+
+// getShellCommandArgs returns the shell command and arguments for executing a command string
+func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
+ if cmdString == "" {
+ return []string{shell, "-l"}
+ }
+ return []string{shell, "-l", "-c", cmdString}
+}
+
+// prepareCommandEnv prepares environment variables for command execution on Unix
+func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
+ env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
+ env = append(env, prepareSSHEnv(session)...)
+ for _, v := range session.Environ() {
+ if acceptEnv(v) {
+ env = append(env, v)
+ }
+ }
+ return env
+}
+
+// executeCommandWithPty executes a command with PTY allocation
+func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
+ termType := ptyReq.Term
+ if termType == "" {
+ termType = "xterm-256color"
+ }
+ execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType))
+
+ return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
+}
+
+func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
+ execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
+ if err != nil {
+ logger.Errorf("Pty command creation failed: %v", err)
+ errorMsg := "User switching failed - login command not available\r\n"
+ if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
+ logger.Debugf(errWriteSession, writeErr)
+ }
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return false
+ }
+
+ logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " "))
+ return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
+}
+
+// runPtyCommand runs a command with PTY management (common code for interactive and command execution)
+func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
+ ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
+ if err != nil {
+ logger.Errorf("Pty start failed: %v", err)
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return false
+ }
+
+ ptyMgr := newPtyManager(ptmx)
+ defer func() {
+ if err := ptyMgr.Close(); err != nil {
+ logger.Debugf("Pty close error: %v", err)
+ }
+ }()
+
+ go s.handlePtyWindowResize(logger, session, ptyMgr, winCh)
+ s.handlePtyIO(logger, session, ptyMgr)
+ s.waitForPtyCompletion(logger, session, execCmd, ptyMgr)
+ return true
+}
+
+func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) {
+ winSize := &pty.Winsize{
+ Cols: uint16(ptyReq.Window.Width),
+ Rows: uint16(ptyReq.Window.Height),
+ }
+ if winSize.Cols == 0 {
+ winSize.Cols = 80
+ }
+ if winSize.Rows == 0 {
+ winSize.Rows = 24
+ }
+
+ ptmx, err := pty.StartWithSize(execCmd, winSize)
+ if err != nil {
+ return nil, fmt.Errorf("start Pty: %w", err)
+ }
+
+ return ptmx, nil
+}
+
+func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) {
+ for {
+ select {
+ case <-session.Context().Done():
+ return
+ case win, ok := <-winCh:
+ if !ok {
+ return
+ }
+ if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil {
+ logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err)
+ }
+ }
+ }
+}
+
+func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) {
+ ptmx := ptyMgr.File()
+
+ go func() {
+ if _, err := io.Copy(ptmx, session); err != nil {
+ if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
+ logger.Warnf("Pty input copy error: %v", err)
+ }
+ }
+ }()
+
+ go func() {
+ defer func() {
+ if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
+ logger.Debugf("session close error: %v", err)
+ }
+ }()
+ if _, err := io.Copy(session, ptmx); err != nil {
+ if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
+ logger.Warnf("Pty output copy error: %v", err)
+ }
+ }
+ }()
+}
+
+func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) {
+ ctx := session.Context()
+ done := make(chan error, 1)
+ go func() {
+ done <- execCmd.Wait()
+ }()
+
+ select {
+ case <-ctx.Done():
+ s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
+ case err := <-done:
+ s.handlePtyCommandCompletion(logger, session, err)
+ }
+}
+
+func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) {
+ logger.Debugf("Pty session cancelled, terminating command")
+ if err := ptyMgr.Close(); err != nil {
+ logger.Debugf("Pty close during session cancellation: %v", err)
+ }
+
+ s.killProcessGroup(execCmd)
+
+ select {
+ case err := <-done:
+ if err != nil {
+ logger.Debugf("Pty command terminated after session cancellation with error: %v", err)
+ } else {
+ logger.Debugf("Pty command terminated after session cancellation")
+ }
+ case <-time.After(5 * time.Second):
+ logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation")
+ }
+
+ if err := session.Exit(130); err != nil {
+ logSessionExitError(logger, err)
+ }
+}
+
+func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
+ if err != nil {
+ logger.Debugf("Pty command execution failed: %v", err)
+ s.handleSessionExit(session, err, logger)
+ return
+ }
+
+ // Normal completion
+ logger.Debugf("Pty command completed successfully")
+ if err := session.Exit(0); err != nil {
+ logSessionExitError(logger, err)
+ }
+}
+
+func (s *Server) setupProcessGroup(cmd *exec.Cmd) {
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setpgid: true,
+ }
+}
+
+func (s *Server) killProcessGroup(cmd *exec.Cmd) {
+ if cmd.Process == nil {
+ return
+ }
+
+ logger := log.WithField("pid", cmd.Process.Pid)
+ pgid := cmd.Process.Pid
+
+ if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
+ logger.Debugf("kill process group SIGTERM: %v", err)
+ return
+ }
+
+ const gracePeriod = 500 * time.Millisecond
+ const checkInterval = 50 * time.Millisecond
+
+ ticker := time.NewTicker(checkInterval)
+ defer ticker.Stop()
+
+ timeout := time.After(gracePeriod)
+
+ for {
+ select {
+ case <-timeout:
+ if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
+ logger.Debugf("kill process group SIGKILL: %v", err)
+ }
+ return
+ case <-ticker.C:
+ if err := syscall.Kill(-pgid, 0); err != nil {
+ return
+ }
+ }
+ }
+}
diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go
new file mode 100644
index 000000000..998796871
--- /dev/null
+++ b/client/ssh/server/command_execution_windows.go
@@ -0,0 +1,435 @@
+package server
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "os/user"
+ "path/filepath"
+ "strings"
+ "unsafe"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/registry"
+
+ "github.com/netbirdio/netbird/client/ssh/server/winpty"
+)
+
+// getUserEnvironment retrieves the Windows environment for the target user.
+// Follows OpenSSH's resilient approach with graceful degradation on failures.
+func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
+ userToken, err := s.getUserToken(username, domain)
+ if err != nil {
+ return nil, fmt.Errorf("get user token: %w", err)
+ }
+ defer func() {
+ if err := windows.CloseHandle(userToken); err != nil {
+ log.Debugf("close user token: %v", err)
+ }
+ }()
+
+ return s.getUserEnvironmentWithToken(userToken, username, domain)
+}
+
+// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
+func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
+ userProfile, err := s.loadUserProfile(userToken, username, domain)
+ if err != nil {
+ log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
+ userProfile = fmt.Sprintf("C:\\Users\\%s", username)
+ }
+
+ envMap := make(map[string]string)
+
+ if err := s.loadSystemEnvironment(envMap); err != nil {
+ log.Debugf("failed to load system environment from registry: %v", err)
+ }
+
+ s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
+
+ var env []string
+ for key, value := range envMap {
+ env = append(env, key+"="+value)
+ }
+
+ return env, nil
+}
+
+// getUserToken creates a user token for the specified user.
+func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
+ privilegeDropper := NewPrivilegeDropper()
+ token, err := privilegeDropper.createToken(username, domain)
+ if err != nil {
+ return 0, fmt.Errorf("generate S4U user token: %w", err)
+ }
+ return token, nil
+}
+
+// loadUserProfile loads the Windows user profile and returns the profile path.
+func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) {
+ usernamePtr, err := windows.UTF16PtrFromString(username)
+ if err != nil {
+ return "", fmt.Errorf("convert username to UTF-16: %w", err)
+ }
+
+ var domainUTF16 *uint16
+ if domain != "" && domain != "." {
+ domainUTF16, err = windows.UTF16PtrFromString(domain)
+ if err != nil {
+ return "", fmt.Errorf("convert domain to UTF-16: %w", err)
+ }
+ }
+
+ type profileInfo struct {
+ dwSize uint32
+ dwFlags uint32
+ lpUserName *uint16
+ lpProfilePath *uint16
+ lpDefaultPath *uint16
+ lpServerName *uint16
+ lpPolicyPath *uint16
+ hProfile windows.Handle
+ }
+
+ const PI_NOUI = 0x00000001
+
+ profile := profileInfo{
+ dwSize: uint32(unsafe.Sizeof(profileInfo{})),
+ dwFlags: PI_NOUI,
+ lpUserName: usernamePtr,
+ lpServerName: domainUTF16,
+ }
+
+ userenv := windows.NewLazySystemDLL("userenv.dll")
+ loadUserProfileW := userenv.NewProc("LoadUserProfileW")
+
+ ret, _, err := loadUserProfileW.Call(
+ uintptr(userToken),
+ uintptr(unsafe.Pointer(&profile)),
+ )
+
+ if ret == 0 {
+ return "", fmt.Errorf("LoadUserProfileW: %w", err)
+ }
+
+ if profile.lpProfilePath == nil {
+ return "", fmt.Errorf("LoadUserProfileW returned null profile path")
+ }
+
+ profilePath := windows.UTF16PtrToString(profile.lpProfilePath)
+ return profilePath, nil
+}
+
+// loadSystemEnvironment loads system-wide environment variables from registry.
+func (s *Server) loadSystemEnvironment(envMap map[string]string) error {
+ key, err := registry.OpenKey(registry.LOCAL_MACHINE,
+ `SYSTEM\CurrentControlSet\Control\Session Manager\Environment`,
+ registry.QUERY_VALUE)
+ if err != nil {
+ return fmt.Errorf("open system environment registry key: %w", err)
+ }
+ defer func() {
+ if err := key.Close(); err != nil {
+ log.Debugf("close registry key: %v", err)
+ }
+ }()
+
+ return s.readRegistryEnvironment(key, envMap)
+}
+
+// readRegistryEnvironment reads environment variables from a registry key.
+func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error {
+ names, err := key.ReadValueNames(0)
+ if err != nil {
+ return fmt.Errorf("read registry value names: %w", err)
+ }
+
+ for _, name := range names {
+ value, valueType, err := key.GetStringValue(name)
+ if err != nil {
+ log.Debugf("failed to read registry value %s: %v", name, err)
+ continue
+ }
+
+ finalValue := s.expandRegistryValue(value, valueType, name)
+ s.setEnvironmentVariable(envMap, name, finalValue)
+ }
+
+ return nil
+}
+
+// expandRegistryValue expands registry values if they contain environment variables.
+func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string {
+ if valueType != registry.EXPAND_SZ {
+ return value
+ }
+
+ sourcePtr := windows.StringToUTF16Ptr(value)
+ expandedBuffer := make([]uint16, 1024)
+ expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
+ if err != nil {
+ log.Debugf("failed to expand environment string for %s: %v", name, err)
+ return value
+ }
+
+ // If buffer was too small, retry with larger buffer
+ if expandedLen > uint32(len(expandedBuffer)) {
+ expandedBuffer = make([]uint16, expandedLen)
+ expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
+ if err != nil {
+ log.Debugf("failed to expand environment string for %s on retry: %v", name, err)
+ return value
+ }
+ }
+
+ if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) {
+ return windows.UTF16ToString(expandedBuffer[:expandedLen-1])
+ }
+ return value
+}
+
+// setEnvironmentVariable sets an environment variable with special handling for PATH.
+func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) {
+ upperName := strings.ToUpper(name)
+
+ if upperName == "PATH" {
+ if existing, exists := envMap["PATH"]; exists && existing != value {
+ envMap["PATH"] = existing + ";" + value
+ } else {
+ envMap["PATH"] = value
+ }
+ } else {
+ envMap[upperName] = value
+ }
+}
+
+// setUserEnvironmentVariables sets critical user-specific environment variables.
+func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) {
+ envMap["USERPROFILE"] = userProfile
+
+ if len(userProfile) >= 2 && userProfile[1] == ':' {
+ envMap["HOMEDRIVE"] = userProfile[:2]
+ envMap["HOMEPATH"] = userProfile[2:]
+ }
+
+ envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming")
+ envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local")
+
+ tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp")
+ envMap["TEMP"] = tempDir
+ envMap["TMP"] = tempDir
+
+ envMap["USERNAME"] = username
+ if domain != "" && domain != "." {
+ envMap["USERDOMAIN"] = domain
+ envMap["USERDNSDOMAIN"] = domain
+ }
+
+ systemVars := []string{
+ "PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION",
+ "SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT",
+ "PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC",
+ }
+
+ for _, sysVar := range systemVars {
+ if sysValue := os.Getenv(sysVar); sysValue != "" {
+ envMap[sysVar] = sysValue
+ }
+ }
+}
+
+// prepareCommandEnv prepares environment variables for command execution on Windows
+func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
+ username, domain := s.parseUsername(localUser.Username)
+ userEnv, err := s.getUserEnvironment(username, domain)
+ if err != nil {
+ log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
+ env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
+ env = append(env, prepareSSHEnv(session)...)
+ for _, v := range session.Environ() {
+ if acceptEnv(v) {
+ env = append(env, v)
+ }
+ }
+ return env
+ }
+
+ env := userEnv
+ env = append(env, prepareSSHEnv(session)...)
+ for _, v := range session.Environ() {
+ if acceptEnv(v) {
+ env = append(env, v)
+ }
+ }
+ return env
+}
+
+func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
+ if privilegeResult.User == nil {
+ logger.Errorf("no user in privilege result")
+ return false
+ }
+
+ cmd := session.Command()
+ shell := getUserShell(privilegeResult.User.Uid)
+
+ if len(cmd) == 0 {
+ logger.Infof("starting interactive shell: %s", shell)
+ } else {
+ logger.Infof("executing command: %s", safeLogCommand(cmd))
+ }
+
+ s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
+ return true
+}
+
+// getShellCommandArgs returns the shell command and arguments for executing a command string
+func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
+ if cmdString == "" {
+ return []string{shell, "-NoLogo"}
+ }
+ return []string{shell, "-Command", cmdString}
+}
+
+func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
+ logger.Info("starting interactive shell")
+ s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
+}
+
+type PtyExecutionRequest struct {
+ Shell string
+ Command string
+ Width int
+ Height int
+ Username string
+ Domain string
+}
+
+func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
+ log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
+ req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
+
+ privilegeDropper := NewPrivilegeDropper()
+ userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
+ if err != nil {
+ return fmt.Errorf("create user token: %w", err)
+ }
+ defer func() {
+ if err := windows.CloseHandle(userToken); err != nil {
+ log.Debugf("close user token: %v", err)
+ }
+ }()
+
+ server := &Server{}
+ userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
+ if err != nil {
+ log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
+ userEnv = os.Environ()
+ }
+
+ workingDir := getUserHomeFromEnv(userEnv)
+ if workingDir == "" {
+ workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username)
+ }
+
+ ptyConfig := winpty.PtyConfig{
+ Shell: req.Shell,
+ Command: req.Command,
+ Width: req.Width,
+ Height: req.Height,
+ WorkingDir: workingDir,
+ }
+
+ userConfig := winpty.UserConfig{
+ Token: userToken,
+ Environment: userEnv,
+ }
+
+ log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
+ return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
+}
+
+func getUserHomeFromEnv(env []string) string {
+ for _, envVar := range env {
+ if len(envVar) > 12 && envVar[:12] == "USERPROFILE=" {
+ return envVar[12:]
+ }
+ }
+ return ""
+}
+
+func (s *Server) setupProcessGroup(_ *exec.Cmd) {
+ // Windows doesn't support process groups in the same way as Unix
+ // Process creation groups are handled differently
+}
+
+func (s *Server) killProcessGroup(cmd *exec.Cmd) {
+ if cmd.Process == nil {
+ return
+ }
+
+ logger := log.WithField("pid", cmd.Process.Pid)
+
+ if err := cmd.Process.Kill(); err != nil {
+ logger.Debugf("kill process failed: %v", err)
+ }
+}
+
+// detectSuPtySupport always returns false on Windows as su is not available
+func (s *Server) detectSuPtySupport(context.Context) bool {
+ return false
+}
+
+// detectUtilLinuxLogin always returns false on Windows
+func (s *Server) detectUtilLinuxLogin(context.Context) bool {
+ return false
+}
+
+// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
+func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
+ command := session.RawCommand()
+ if command == "" {
+ logger.Error("no command specified for PTY execution")
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return false
+ }
+
+ return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
+}
+
+// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
+func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
+ localUser := privilegeResult.User
+ if localUser == nil {
+ logger.Errorf("no user in privilege result")
+ return false
+ }
+
+ username, domain := s.parseUsername(localUser.Username)
+ shell := getUserShell(localUser.Uid)
+
+ req := PtyExecutionRequest{
+ Shell: shell,
+ Command: command,
+ Width: ptyReq.Window.Width,
+ Height: ptyReq.Window.Height,
+ Username: username,
+ Domain: domain,
+ }
+
+ if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
+ logger.Errorf("ConPty execution failed: %v", err)
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return false
+ }
+
+ logger.Debug("ConPty execution completed")
+ return true
+}
diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go
new file mode 100644
index 000000000..34ffccfd2
--- /dev/null
+++ b/client/ssh/server/compatibility_test.go
@@ -0,0 +1,722 @@
+package server
+
+import (
+ "context"
+ "crypto/ed25519"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "os/exec"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+
+ nbssh "github.com/netbirdio/netbird/client/ssh"
+ "github.com/netbirdio/netbird/client/ssh/testutil"
+)
+
+// TestMain handles package-level setup and cleanup
+func TestMain(m *testing.M) {
+ // Guard against infinite recursion when test binary is called as "netbird ssh exec"
+ // This happens when running tests as non-privileged user with fallback
+ if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
+ // Just exit with error to break the recursion
+ fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
+ os.Exit(1)
+ }
+
+ // Run tests
+ code := m.Run()
+
+ // Cleanup any created test users
+ testutil.CleanupTestUsers()
+
+ os.Exit(code)
+}
+
+// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
+func TestSSHServerCompatibility(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping SSH compatibility tests in short mode")
+ }
+
+ // Check if ssh binary is available
+ if !isSSHClientAvailable() {
+ t.Skip("SSH client not available on this system")
+ }
+
+ // Set up SSH server - use our existing key generation for server
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ // Generate OpenSSH-compatible keys for client
+ clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
+ require.NoError(t, err)
+
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ server.SetAllowRootLogin(true)
+
+ serverAddr := StartTestServer(t, server)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ // Create temporary key files for SSH client
+ clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
+ defer cleanupKey()
+
+ // Extract host and port from server address
+ host, portStr, err := net.SplitHostPort(serverAddr)
+ require.NoError(t, err)
+
+ // Get appropriate user for SSH connection (handle system accounts)
+ username := testutil.GetTestUsername(t)
+
+ t.Run("basic command execution", func(t *testing.T) {
+ testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
+ })
+
+ t.Run("interactive command", func(t *testing.T) {
+ testSSHInteractiveCommand(t, host, portStr, clientKeyFile)
+ })
+
+ t.Run("port forwarding", func(t *testing.T) {
+ testSSHPortForwarding(t, host, portStr, clientKeyFile)
+ })
+}
+
+// testSSHCommandExecutionWithUser tests basic command execution with system SSH client using specified user.
+func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username string) {
+ cmd := exec.Command("ssh",
+ "-i", keyFile,
+ "-p", port,
+ "-o", "StrictHostKeyChecking=no",
+ "-o", "UserKnownHostsFile=/dev/null",
+ "-o", "ConnectTimeout=5",
+ fmt.Sprintf("%s@%s", username, host),
+ "echo", "hello_world")
+
+ output, err := cmd.CombinedOutput()
+
+ if err != nil {
+ t.Logf("SSH command failed: %v", err)
+ t.Logf("Output: %s", string(output))
+ return
+ }
+
+ assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully")
+}
+
+// testSSHInteractiveCommand tests interactive shell session.
+func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
+ // Get appropriate user for SSH connection
+ username := testutil.GetTestUsername(t)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ cmd := exec.CommandContext(ctx, "ssh",
+ "-i", keyFile,
+ "-p", port,
+ "-o", "StrictHostKeyChecking=no",
+ "-o", "UserKnownHostsFile=/dev/null",
+ "-o", "ConnectTimeout=5",
+ fmt.Sprintf("%s@%s", username, host))
+
+ stdin, err := cmd.StdinPipe()
+ if err != nil {
+ t.Skipf("Cannot create stdin pipe: %v", err)
+ return
+ }
+
+ stdout, err := cmd.StdoutPipe()
+ if err != nil {
+ t.Skipf("Cannot create stdout pipe: %v", err)
+ return
+ }
+
+ err = cmd.Start()
+ if err != nil {
+ t.Logf("Cannot start SSH session: %v", err)
+ return
+ }
+
+ go func() {
+ defer func() {
+ if err := stdin.Close(); err != nil {
+ t.Logf("stdin close error: %v", err)
+ }
+ }()
+ time.Sleep(100 * time.Millisecond)
+ if _, err := stdin.Write([]byte("echo interactive_test\n")); err != nil {
+ t.Logf("stdin write error: %v", err)
+ }
+ time.Sleep(100 * time.Millisecond)
+ if _, err := stdin.Write([]byte("exit\n")); err != nil {
+ t.Logf("stdin write error: %v", err)
+ }
+ }()
+
+ output, err := io.ReadAll(stdout)
+ if err != nil {
+ t.Logf("Cannot read SSH output: %v", err)
+ }
+
+ err = cmd.Wait()
+ if err != nil {
+ t.Logf("SSH interactive session error: %v", err)
+ t.Logf("Output: %s", string(output))
+ return
+ }
+
+ assert.Contains(t, string(output), "interactive_test", "Interactive SSH session should work")
+}
+
+// testSSHPortForwarding tests port forwarding compatibility.
+func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
+ // Get appropriate user for SSH connection
+ username := testutil.GetTestUsername(t)
+
+ testServer, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ defer testServer.Close()
+
+ testServerAddr := testServer.Addr().String()
+ expectedResponse := "HTTP/1.1 200 OK\r\nContent-Length: 21\r\n\r\nCompatibility Test OK"
+
+ go func() {
+ for {
+ conn, err := testServer.Accept()
+ if err != nil {
+ return
+ }
+ go func(c net.Conn) {
+ defer func() {
+ if err := c.Close(); err != nil {
+ t.Logf("test server connection close error: %v", err)
+ }
+ }()
+ buf := make([]byte, 1024)
+ if _, err := c.Read(buf); err != nil {
+ t.Logf("Test server read error: %v", err)
+ }
+ if _, err := c.Write([]byte(expectedResponse)); err != nil {
+ t.Logf("Test server write error: %v", err)
+ }
+ }(conn)
+ }
+ }()
+
+ localListener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ localAddr := localListener.Addr().String()
+ localListener.Close()
+
+ _, localPort, err := net.SplitHostPort(localAddr)
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancel()
+
+ forwardSpec := fmt.Sprintf("%s:%s", localPort, testServerAddr)
+ cmd := exec.CommandContext(ctx, "ssh",
+ "-i", keyFile,
+ "-p", port,
+ "-L", forwardSpec,
+ "-o", "StrictHostKeyChecking=no",
+ "-o", "UserKnownHostsFile=/dev/null",
+ "-o", "ConnectTimeout=5",
+ "-N",
+ fmt.Sprintf("%s@%s", username, host))
+
+ err = cmd.Start()
+ if err != nil {
+ t.Logf("Cannot start SSH port forwarding: %v", err)
+ return
+ }
+
+ defer func() {
+ if cmd.Process != nil {
+ if err := cmd.Process.Kill(); err != nil {
+ t.Logf("process kill error: %v", err)
+ }
+ }
+ if err := cmd.Wait(); err != nil {
+ t.Logf("process wait after kill: %v", err)
+ }
+ }()
+
+ time.Sleep(500 * time.Millisecond)
+
+ conn, err := net.DialTimeout("tcp", localAddr, 3*time.Second)
+ if err != nil {
+ t.Logf("Cannot connect to forwarded port: %v", err)
+ return
+ }
+ defer func() {
+ if err := conn.Close(); err != nil {
+ t.Logf("forwarded connection close error: %v", err)
+ }
+ }()
+
+ request := "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
+ _, err = conn.Write([]byte(request))
+ require.NoError(t, err)
+
+ if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
+ log.Debugf("failed to set read deadline: %v", err)
+ }
+ response := make([]byte, len(expectedResponse))
+ n, err := io.ReadFull(conn, response)
+ if err != nil {
+ t.Logf("Cannot read forwarded response: %v", err)
+ return
+ }
+
+ assert.Equal(t, len(expectedResponse), n, "Should read expected number of bytes")
+ assert.Equal(t, expectedResponse, string(response), "Should get correct HTTP response through SSH port forwarding")
+}
+
+// isSSHClientAvailable checks if the ssh binary is available
+func isSSHClientAvailable() bool {
+ _, err := exec.LookPath("ssh")
+ return err == nil
+}
+
+// generateOpenSSHKey generates an ED25519 key in OpenSSH format that the system SSH client can use.
+func generateOpenSSHKey(t *testing.T) ([]byte, []byte, error) {
+ // Check if ssh-keygen is available
+ if _, err := exec.LookPath("ssh-keygen"); err != nil {
+ // Fall back to our existing key generation and try to convert
+ return generateOpenSSHKeyFallback()
+ }
+
+ // Create temporary file for ssh-keygen
+ tempFile, err := os.CreateTemp("", "ssh_keygen_*")
+ if err != nil {
+ return nil, nil, fmt.Errorf("create temp file: %w", err)
+ }
+ keyPath := tempFile.Name()
+ tempFile.Close()
+
+ // Remove the temp file so ssh-keygen can create it
+ if err := os.Remove(keyPath); err != nil {
+ t.Logf("failed to remove key file: %v", err)
+ }
+
+ // Clean up temp files
+ defer func() {
+ if err := os.Remove(keyPath); err != nil {
+ t.Logf("failed to cleanup key file: %v", err)
+ }
+ if err := os.Remove(keyPath + ".pub"); err != nil {
+ t.Logf("failed to cleanup public key file: %v", err)
+ }
+ }()
+
+ // Generate key using ssh-keygen
+ cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", keyPath, "-N", "", "-q")
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ return nil, nil, fmt.Errorf("ssh-keygen failed: %w, output: %s", err, string(output))
+ }
+
+ // Read private key
+ privKeyBytes, err := os.ReadFile(keyPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("read private key: %w", err)
+ }
+
+ // Read public key
+ pubKeyBytes, err := os.ReadFile(keyPath + ".pub")
+ if err != nil {
+ return nil, nil, fmt.Errorf("read public key: %w", err)
+ }
+
+ return privKeyBytes, pubKeyBytes, nil
+}
+
+// generateOpenSSHKeyFallback falls back to generating keys using our existing method
+func generateOpenSSHKeyFallback() ([]byte, []byte, error) {
+ // Generate shared.ED25519 key pair using our existing method
+ _, privKey, err := ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ return nil, nil, fmt.Errorf("generate key: %w", err)
+ }
+
+ // Convert to SSH format
+ sshPrivKey, err := ssh.NewSignerFromKey(privKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("create signer: %w", err)
+ }
+
+ // For the fallback, just use our PKCS#8 format and hope it works
+ // This won't be in OpenSSH format but might still work with some SSH clients
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ if err != nil {
+ return nil, nil, fmt.Errorf("generate fallback key: %w", err)
+ }
+
+ // Get public key in SSH format
+ sshPubKey := ssh.MarshalAuthorizedKey(sshPrivKey.PublicKey())
+
+ return hostKey, sshPubKey, nil
+}
+
+// createTempKeyFileFromBytes creates a temporary SSH private key file from raw bytes
+func createTempKeyFileFromBytes(t *testing.T, keyBytes []byte) (string, func()) {
+ t.Helper()
+
+ tempFile, err := os.CreateTemp("", "ssh_test_key_*")
+ require.NoError(t, err)
+
+ _, err = tempFile.Write(keyBytes)
+ require.NoError(t, err)
+
+ err = tempFile.Close()
+ require.NoError(t, err)
+
+ // Set proper permissions for SSH key (readable by owner only)
+ err = os.Chmod(tempFile.Name(), 0600)
+ require.NoError(t, err)
+
+ cleanup := func() {
+ _ = os.Remove(tempFile.Name())
+ }
+
+ return tempFile.Name(), cleanup
+}
+
+// createTempKeyFile creates a temporary SSH private key file (for backward compatibility)
+func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
+ return createTempKeyFileFromBytes(t, privateKey)
+}
+
+// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
+func TestSSHServerFeatureCompatibility(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping SSH feature compatibility tests in short mode")
+ }
+
+ if runtime.GOOS == "windows" && testutil.IsCI() {
+ t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
+ }
+
+ if !isSSHClientAvailable() {
+ t.Skip("SSH client not available on this system")
+ }
+
+ // Test various SSH features
+ testCases := []struct {
+ name string
+ testFunc func(t *testing.T, host, port, keyFile string)
+ description string
+ }{
+ {
+ name: "command_with_flags",
+ testFunc: testCommandWithFlags,
+ description: "Commands with flags should work like standard SSH",
+ },
+ {
+ name: "environment_variables",
+ testFunc: testEnvironmentVariables,
+ description: "Environment variables should be available",
+ },
+ {
+ name: "exit_codes",
+ testFunc: testExitCodes,
+ description: "Exit codes should be properly handled",
+ },
+ }
+
+ // Set up SSH server
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ server.SetAllowRootLogin(true)
+
+ serverAddr := StartTestServer(t, server)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
+ defer cleanupKey()
+
+ host, portStr, err := net.SplitHostPort(serverAddr)
+ require.NoError(t, err)
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ tc.testFunc(t, host, portStr, clientKeyFile)
+ })
+ }
+}
+
+// testCommandWithFlags tests that commands with flags work properly
+func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
+ // Get appropriate user for SSH connection
+ username := testutil.GetTestUsername(t)
+
+ // Test ls with flags
+ cmd := exec.Command("ssh",
+ "-i", keyFile,
+ "-p", port,
+ "-o", "StrictHostKeyChecking=no",
+ "-o", "UserKnownHostsFile=/dev/null",
+ "-o", "ConnectTimeout=5",
+ fmt.Sprintf("%s@%s", username, host),
+ "ls", "-la", "/tmp")
+
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Logf("Command with flags failed: %v", err)
+ t.Logf("Output: %s", string(output))
+ return
+ }
+
+ // Should not be empty and should not contain error messages
+ assert.NotEmpty(t, string(output), "ls -la should produce output")
+ assert.NotContains(t, strings.ToLower(string(output)), "command not found", "Command should be executed")
+}
+
+// testEnvironmentVariables tests that environment is properly set up
+func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
+ // Get appropriate user for SSH connection
+ username := testutil.GetTestUsername(t)
+
+ cmd := exec.Command("ssh",
+ "-i", keyFile,
+ "-p", port,
+ "-o", "StrictHostKeyChecking=no",
+ "-o", "UserKnownHostsFile=/dev/null",
+ "-o", "ConnectTimeout=5",
+ fmt.Sprintf("%s@%s", username, host),
+ "echo", "$HOME")
+
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Logf("Environment test failed: %v", err)
+ t.Logf("Output: %s", string(output))
+ return
+ }
+
+ // HOME environment variable should be available
+ homeOutput := strings.TrimSpace(string(output))
+ assert.NotEmpty(t, homeOutput, "HOME environment variable should be set")
+ assert.NotEqual(t, "$HOME", homeOutput, "Environment variable should be expanded")
+}
+
+// testExitCodes tests that exit codes are properly handled
+func testExitCodes(t *testing.T, host, port, keyFile string) {
+ // Get appropriate user for SSH connection
+ username := testutil.GetTestUsername(t)
+
+ // Test successful command (exit code 0)
+ cmd := exec.Command("ssh",
+ "-i", keyFile,
+ "-p", port,
+ "-o", "StrictHostKeyChecking=no",
+ "-o", "UserKnownHostsFile=/dev/null",
+ "-o", "ConnectTimeout=5",
+ fmt.Sprintf("%s@%s", username, host),
+ "true") // always succeeds
+
+ err := cmd.Run()
+ assert.NoError(t, err, "Command with exit code 0 should succeed")
+
+ // Test failing command (exit code 1)
+ cmd = exec.Command("ssh",
+ "-i", keyFile,
+ "-p", port,
+ "-o", "StrictHostKeyChecking=no",
+ "-o", "UserKnownHostsFile=/dev/null",
+ "-o", "ConnectTimeout=5",
+ fmt.Sprintf("%s@%s", username, host),
+ "false") // always fails
+
+ err = cmd.Run()
+ assert.Error(t, err, "Command with exit code 1 should fail")
+
+ // Check if it's the right kind of error
+ if exitError, ok := err.(*exec.ExitError); ok {
+ assert.Equal(t, 1, exitError.ExitCode(), "Exit code should be preserved")
+ }
+}
+
+// TestSSHServerSecurityFeatures tests security-related SSH features
+func TestSSHServerSecurityFeatures(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping SSH security tests in short mode")
+ }
+
+ if !isSSHClientAvailable() {
+ t.Skip("SSH client not available on this system")
+ }
+
+ // Get appropriate user for SSH connection
+ username := testutil.GetTestUsername(t)
+
+ // Set up SSH server with specific security settings
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ server.SetAllowRootLogin(true)
+
+ serverAddr := StartTestServer(t, server)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
+ defer cleanupKey()
+
+ host, portStr, err := net.SplitHostPort(serverAddr)
+ require.NoError(t, err)
+
+ t.Run("key_authentication", func(t *testing.T) {
+ // Test that key authentication works
+ cmd := exec.Command("ssh",
+ "-i", clientKeyFile,
+ "-p", portStr,
+ "-o", "StrictHostKeyChecking=no",
+ "-o", "UserKnownHostsFile=/dev/null",
+ "-o", "ConnectTimeout=5",
+ "-o", "PasswordAuthentication=no",
+ fmt.Sprintf("%s@%s", username, host),
+ "echo", "auth_success")
+
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Logf("Key authentication failed: %v", err)
+ t.Logf("Output: %s", string(output))
+ return
+ }
+
+ assert.Contains(t, string(output), "auth_success", "Key authentication should work")
+ })
+
+ t.Run("any_key_accepted_in_no_auth_mode", func(t *testing.T) {
+ // Create a different key that shouldn't be accepted
+ wrongKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ wrongKeyFile, cleanupWrongKey := createTempKeyFile(t, wrongKey)
+ defer cleanupWrongKey()
+
+ // Test that wrong key is rejected
+ cmd := exec.Command("ssh",
+ "-i", wrongKeyFile,
+ "-p", portStr,
+ "-o", "StrictHostKeyChecking=no",
+ "-o", "UserKnownHostsFile=/dev/null",
+ "-o", "ConnectTimeout=5",
+ "-o", "PasswordAuthentication=no",
+ fmt.Sprintf("%s@%s", username, host),
+ "echo", "should_not_work")
+
+ err = cmd.Run()
+ assert.NoError(t, err, "Any key should work in no-auth mode")
+ })
+}
+
+// TestCrossPlatformCompatibility tests cross-platform behavior
+func TestCrossPlatformCompatibility(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping cross-platform compatibility tests in short mode")
+ }
+
+ if !isSSHClientAvailable() {
+ t.Skip("SSH client not available on this system")
+ }
+
+ // Get appropriate user for SSH connection
+ username := testutil.GetTestUsername(t)
+
+ // Set up SSH server
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ server.SetAllowRootLogin(true)
+
+ serverAddr := StartTestServer(t, server)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
+ defer cleanupKey()
+
+ host, portStr, err := net.SplitHostPort(serverAddr)
+ require.NoError(t, err)
+
+ // Test platform-specific commands
+ var testCommand string
+
+ switch runtime.GOOS {
+ case "windows":
+ testCommand = "echo %OS%"
+ default:
+ testCommand = "uname"
+ }
+
+ cmd := exec.Command("ssh",
+ "-i", clientKeyFile,
+ "-p", portStr,
+ "-o", "StrictHostKeyChecking=no",
+ "-o", "UserKnownHostsFile=/dev/null",
+ "-o", "ConnectTimeout=5",
+ fmt.Sprintf("%s@%s", username, host),
+ testCommand)
+
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Logf("Platform-specific command failed: %v", err)
+ t.Logf("Output: %s", string(output))
+ return
+ }
+
+ outputStr := strings.TrimSpace(string(output))
+ t.Logf("Platform command output: %s", outputStr)
+ assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
+}
diff --git a/client/ssh/server/executor_unix.go b/client/ssh/server/executor_unix.go
new file mode 100644
index 000000000..8adc824ef
--- /dev/null
+++ b/client/ssh/server/executor_unix.go
@@ -0,0 +1,253 @@
+//go:build unix
+
+package server
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "runtime"
+ "strings"
+ "syscall"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// Exit codes for executor process communication
+const (
+ ExitCodeSuccess = 0
+ ExitCodePrivilegeDropFail = 10
+ ExitCodeShellExecFail = 11
+ ExitCodeValidationFail = 12
+)
+
+// ExecutorConfig holds configuration for the executor process
+type ExecutorConfig struct {
+ UID uint32
+ GID uint32
+ Groups []uint32
+ WorkingDir string
+ Shell string
+ Command string
+ PTY bool
+}
+
+// PrivilegeDropper handles secure privilege dropping in child processes
+type PrivilegeDropper struct{}
+
+// NewPrivilegeDropper creates a new privilege dropper
+func NewPrivilegeDropper() *PrivilegeDropper {
+ return &PrivilegeDropper{}
+}
+
+// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
+func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config ExecutorConfig) (*exec.Cmd, error) {
+ netbirdPath, err := os.Executable()
+ if err != nil {
+ return nil, fmt.Errorf("get netbird executable path: %w", err)
+ }
+
+ if err := pd.validatePrivileges(config.UID, config.GID); err != nil {
+ return nil, fmt.Errorf("invalid privileges: %w", err)
+ }
+
+ args := []string{
+ "ssh", "exec",
+ "--uid", fmt.Sprintf("%d", config.UID),
+ "--gid", fmt.Sprintf("%d", config.GID),
+ "--working-dir", config.WorkingDir,
+ "--shell", config.Shell,
+ }
+
+ for _, group := range config.Groups {
+ args = append(args, "--groups", fmt.Sprintf("%d", group))
+ }
+
+ if config.PTY {
+ args = append(args, "--pty")
+ }
+
+ if config.Command != "" {
+ args = append(args, "--cmd", config.Command)
+ }
+
+ // Log executor args safely - show all args except hide the command value
+ safeArgs := make([]string, len(args))
+ copy(safeArgs, args)
+ for i := 0; i < len(safeArgs)-1; i++ {
+ if safeArgs[i] == "--cmd" {
+ cmdParts := strings.Fields(safeArgs[i+1])
+ safeArgs[i+1] = safeLogCommand(cmdParts)
+ break
+ }
+ }
+ log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
+ return exec.CommandContext(ctx, netbirdPath, args...), nil
+}
+
+// DropPrivileges performs privilege dropping with thread locking for security
+func (pd *PrivilegeDropper) DropPrivileges(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
+ if err := pd.validatePrivileges(targetUID, targetGID); err != nil {
+ return fmt.Errorf("invalid privileges: %w", err)
+ }
+
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ originalUID := os.Geteuid()
+ originalGID := os.Getegid()
+
+ if originalUID != int(targetUID) || originalGID != int(targetGID) {
+ if err := pd.setGroupsAndIDs(targetUID, targetGID, supplementaryGroups); err != nil {
+ return fmt.Errorf("set groups and IDs: %w", err)
+ }
+ }
+
+ if err := pd.validatePrivilegeDropSuccess(targetUID, targetGID, originalUID, originalGID); err != nil {
+ return err
+ }
+
+ log.Tracef("successfully dropped privileges to UID=%d, GID=%d", targetUID, targetGID)
+ return nil
+}
+
+// setGroupsAndIDs sets the supplementary groups, GID, and UID
+func (pd *PrivilegeDropper) setGroupsAndIDs(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
+ groups := make([]int, len(supplementaryGroups))
+ for i, g := range supplementaryGroups {
+ groups[i] = int(g)
+ }
+
+ if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" {
+ if len(groups) == 0 || groups[0] != int(targetGID) {
+ groups = append([]int{int(targetGID)}, groups...)
+ }
+ }
+
+ if err := syscall.Setgroups(groups); err != nil {
+ return fmt.Errorf("setgroups to %v: %w", groups, err)
+ }
+
+ if err := syscall.Setgid(int(targetGID)); err != nil {
+ return fmt.Errorf("setgid to %d: %w", targetGID, err)
+ }
+
+ if err := syscall.Setuid(int(targetUID)); err != nil {
+ return fmt.Errorf("setuid to %d: %w", targetUID, err)
+ }
+
+ return nil
+}
+
+// validatePrivilegeDropSuccess validates that privilege dropping was successful
+func (pd *PrivilegeDropper) validatePrivilegeDropSuccess(targetUID, targetGID uint32, originalUID, originalGID int) error {
+ if err := pd.validatePrivilegeDropReversibility(targetUID, targetGID, originalUID, originalGID); err != nil {
+ return err
+ }
+
+ if err := pd.validateCurrentPrivileges(targetUID, targetGID); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// validatePrivilegeDropReversibility ensures privileges cannot be restored
+func (pd *PrivilegeDropper) validatePrivilegeDropReversibility(targetUID, targetGID uint32, originalUID, originalGID int) error {
+ if originalGID != int(targetGID) {
+ if err := syscall.Setegid(originalGID); err == nil {
+ return fmt.Errorf("privilege drop validation failed: able to restore original GID %d", originalGID)
+ }
+ }
+ if originalUID != int(targetUID) {
+ if err := syscall.Seteuid(originalUID); err == nil {
+ return fmt.Errorf("privilege drop validation failed: able to restore original UID %d", originalUID)
+ }
+ }
+ return nil
+}
+
+// validateCurrentPrivileges validates the current UID and GID match the target
+func (pd *PrivilegeDropper) validateCurrentPrivileges(targetUID, targetGID uint32) error {
+ currentUID := os.Geteuid()
+ if currentUID != int(targetUID) {
+ return fmt.Errorf("privilege drop validation failed: current UID %d, expected %d", currentUID, targetUID)
+ }
+
+ currentGID := os.Getegid()
+ if currentGID != int(targetGID) {
+ return fmt.Errorf("privilege drop validation failed: current GID %d, expected %d", currentGID, targetGID)
+ }
+
+ return nil
+}
+
+// ExecuteWithPrivilegeDrop executes a command with privilege dropping, using exit codes to signal specific failures
+func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config ExecutorConfig) {
+ log.Tracef("dropping privileges to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
+
+ // TODO: Implement Pty support for executor path
+ if config.PTY {
+ config.PTY = false
+ }
+
+ if err := pd.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
+ _, _ = fmt.Fprintf(os.Stderr, "privilege drop failed: %v\n", err)
+ os.Exit(ExitCodePrivilegeDropFail)
+ }
+
+ if config.WorkingDir != "" {
+ if err := os.Chdir(config.WorkingDir); err != nil {
+ log.Debugf("failed to change to working directory %s, continuing with current directory: %v", config.WorkingDir, err)
+ }
+ }
+
+ var execCmd *exec.Cmd
+ if config.Command == "" {
+ os.Exit(ExitCodeSuccess)
+ }
+
+ execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
+ execCmd.Stdin = os.Stdin
+ execCmd.Stdout = os.Stdout
+ execCmd.Stderr = os.Stderr
+
+ cmdParts := strings.Fields(config.Command)
+ safeCmd := safeLogCommand(cmdParts)
+ log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
+ if err := execCmd.Run(); err != nil {
+ var exitError *exec.ExitError
+ if errors.As(err, &exitError) {
+ // Normal command exit with non-zero code - not an SSH execution error
+ log.Tracef("command exited with code %d", exitError.ExitCode())
+ os.Exit(exitError.ExitCode())
+ }
+
+ // Actual execution failure (command not found, permission denied, etc.)
+ log.Debugf("command execution failed: %v", err)
+ os.Exit(ExitCodeShellExecFail)
+ }
+
+ os.Exit(ExitCodeSuccess)
+}
+
+// validatePrivileges validates that privilege dropping to the target UID/GID is allowed
+func (pd *PrivilegeDropper) validatePrivileges(uid, gid uint32) error {
+ currentUID := uint32(os.Geteuid())
+ currentGID := uint32(os.Getegid())
+
+ // Allow same-user operations (no privilege dropping needed)
+ if uid == currentUID && gid == currentGID {
+ return nil
+ }
+
+ // Only root can drop privileges to other users
+ if currentUID != 0 {
+ return fmt.Errorf("cannot drop privileges from non-root user (UID %d) to UID %d", currentUID, uid)
+ }
+
+ // Root can drop to any user (including root itself)
+ return nil
+}
diff --git a/client/ssh/server/executor_unix_test.go b/client/ssh/server/executor_unix_test.go
new file mode 100644
index 000000000..0c5108f57
--- /dev/null
+++ b/client/ssh/server/executor_unix_test.go
@@ -0,0 +1,262 @@
+//go:build unix
+
+package server
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "os/user"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
+ pd := NewPrivilegeDropper()
+
+ currentUID := uint32(os.Geteuid())
+ currentGID := uint32(os.Getegid())
+
+ tests := []struct {
+ name string
+ uid uint32
+ gid uint32
+ wantErr bool
+ }{
+ {
+ name: "same user - no privilege drop needed",
+ uid: currentUID,
+ gid: currentGID,
+ wantErr: false,
+ },
+ {
+ name: "non-root to different user should fail",
+ uid: currentUID + 1, // Use a different UID to ensure it's actually different
+ gid: currentGID + 1, // Use a different GID to ensure it's actually different
+ wantErr: currentUID != 0, // Only fail if current user is not root
+ },
+ {
+ name: "root can drop to any user",
+ uid: 1000,
+ gid: 1000,
+ wantErr: false,
+ },
+ {
+ name: "root can stay as root",
+ uid: 0,
+ gid: 0,
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Skip non-root tests when running as root, and root tests when not root
+ if tt.name == "non-root to different user should fail" && currentUID == 0 {
+ t.Skip("Skipping non-root test when running as root")
+ }
+ if (tt.name == "root can drop to any user" || tt.name == "root can stay as root") && currentUID != 0 {
+ t.Skip("Skipping root test when not running as root")
+ }
+
+ err := pd.validatePrivileges(tt.uid, tt.gid)
+ if tt.wantErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
+ pd := NewPrivilegeDropper()
+
+ config := ExecutorConfig{
+ UID: 1000,
+ GID: 1000,
+ Groups: []uint32{1000, 1001},
+ WorkingDir: "/home/testuser",
+ Shell: "/bin/bash",
+ Command: "ls -la",
+ }
+
+ cmd, err := pd.CreateExecutorCommand(context.Background(), config)
+ require.NoError(t, err)
+ require.NotNil(t, cmd)
+
+ // Verify the command is calling netbird ssh exec
+ assert.Contains(t, cmd.Args, "ssh")
+ assert.Contains(t, cmd.Args, "exec")
+ assert.Contains(t, cmd.Args, "--uid")
+ assert.Contains(t, cmd.Args, "1000")
+ assert.Contains(t, cmd.Args, "--gid")
+ assert.Contains(t, cmd.Args, "1000")
+ assert.Contains(t, cmd.Args, "--groups")
+ assert.Contains(t, cmd.Args, "1000")
+ assert.Contains(t, cmd.Args, "1001")
+ assert.Contains(t, cmd.Args, "--working-dir")
+ assert.Contains(t, cmd.Args, "/home/testuser")
+ assert.Contains(t, cmd.Args, "--shell")
+ assert.Contains(t, cmd.Args, "/bin/bash")
+ assert.Contains(t, cmd.Args, "--cmd")
+ assert.Contains(t, cmd.Args, "ls -la")
+}
+
+func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
+ pd := NewPrivilegeDropper()
+
+ config := ExecutorConfig{
+ UID: 1000,
+ GID: 1000,
+ Groups: []uint32{1000},
+ WorkingDir: "/home/testuser",
+ Shell: "/bin/bash",
+ Command: "",
+ }
+
+ cmd, err := pd.CreateExecutorCommand(context.Background(), config)
+ require.NoError(t, err)
+ require.NotNil(t, cmd)
+
+ // Verify no command mode (command is empty so no --cmd flag)
+ assert.NotContains(t, cmd.Args, "--cmd")
+ assert.NotContains(t, cmd.Args, "--interactive")
+}
+
+// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
+// This test requires root privileges and will be skipped if not running as root
+func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
+ if os.Geteuid() != 0 {
+ t.Skip("This test requires root privileges")
+ }
+
+ // Find a non-root user to test with
+ testUser, err := findNonRootUser()
+ if err != nil {
+ t.Skip("No suitable non-root user found for testing")
+ }
+
+ // Verify the user actually exists by looking it up again
+ _, err = user.LookupId(testUser.Uid)
+ if err != nil {
+ t.Skipf("Test user %s (UID %s) does not exist on this system: %v", testUser.Username, testUser.Uid, err)
+ }
+
+ uid64, err := strconv.ParseUint(testUser.Uid, 10, 32)
+ require.NoError(t, err)
+ targetUID := uint32(uid64)
+
+ gid64, err := strconv.ParseUint(testUser.Gid, 10, 32)
+ require.NoError(t, err)
+ targetGID := uint32(gid64)
+
+ // Test in a child process to avoid affecting the test runner
+ if os.Getenv("TEST_PRIVILEGE_DROP") == "1" {
+ pd := NewPrivilegeDropper()
+
+ // This should succeed
+ err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID})
+ require.NoError(t, err)
+
+ // Verify we are now running as the target user
+ currentUID := uint32(os.Geteuid())
+ currentGID := uint32(os.Getegid())
+
+ assert.Equal(t, targetUID, currentUID, "UID should match target")
+ assert.Equal(t, targetGID, currentGID, "GID should match target")
+ assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root")
+ assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group")
+
+ return
+ }
+
+ // Fork a child process to test privilege dropping
+ cmd := os.Args[0]
+ args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"}
+
+ env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1")
+
+ execCmd := exec.Command(cmd, args...)
+ execCmd.Env = env
+
+ err = execCmd.Run()
+ require.NoError(t, err, "Child process should succeed")
+}
+
+// findNonRootUser finds any non-root user on the system for testing
+func findNonRootUser() (*user.User, error) {
+ // Try common non-root users, but avoid "nobody" on macOS due to negative UID issues
+ commonUsers := []string{"daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"}
+
+ for _, username := range commonUsers {
+ if u, err := user.Lookup(username); err == nil {
+ // Parse as signed integer first to handle negative UIDs
+ uid64, err := strconv.ParseInt(u.Uid, 10, 32)
+ if err != nil {
+ continue
+ }
+ // Skip negative UIDs (like nobody=-2 on macOS) and root
+ if uid64 > 0 && uid64 != 0 {
+ return u, nil
+ }
+ }
+ }
+
+ // If no common users found, try to find any regular user with UID > 100
+ // This helps on macOS where regular users start at UID 501
+ allUsers := []string{"vma", "user", "test", "admin"}
+ for _, username := range allUsers {
+ if u, err := user.Lookup(username); err == nil {
+ uid64, err := strconv.ParseInt(u.Uid, 10, 32)
+ if err != nil {
+ continue
+ }
+ if uid64 > 100 { // Regular user
+ return u, nil
+ }
+ }
+ }
+
+ // If no common users found, return an error
+ return nil, fmt.Errorf("no suitable non-root user found on this system")
+}
+
+func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) {
+ pd := NewPrivilegeDropper()
+ currentUID := uint32(os.Geteuid())
+
+ if currentUID == 0 {
+ // When running as root, test that root can create commands for any user
+ config := ExecutorConfig{
+ UID: 1000, // Target non-root user
+ GID: 1000,
+ Groups: []uint32{1000},
+ WorkingDir: "/tmp",
+ Shell: "/bin/sh",
+ Command: "echo test",
+ }
+
+ cmd, err := pd.CreateExecutorCommand(context.Background(), config)
+ assert.NoError(t, err, "Root should be able to create commands for any user")
+ assert.NotNil(t, cmd)
+ } else {
+ // When running as non-root, test that we can't drop to a different user
+ config := ExecutorConfig{
+ UID: 0, // Try to target root
+ GID: 0,
+ Groups: []uint32{0},
+ WorkingDir: "/tmp",
+ Shell: "/bin/sh",
+ Command: "echo test",
+ }
+
+ _, err := pd.CreateExecutorCommand(context.Background(), config)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "cannot drop privileges")
+ }
+}
diff --git a/client/ssh/server/executor_windows.go b/client/ssh/server/executor_windows.go
new file mode 100644
index 000000000..d3504e056
--- /dev/null
+++ b/client/ssh/server/executor_windows.go
@@ -0,0 +1,570 @@
+//go:build windows
+
+package server
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "os/user"
+ "strings"
+ "syscall"
+ "unsafe"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+)
+
+const (
+ ExitCodeSuccess = 0
+ ExitCodeLogonFail = 10
+ ExitCodeCreateProcessFail = 11
+ ExitCodeWorkingDirFail = 12
+ ExitCodeShellExecFail = 13
+ ExitCodeValidationFail = 14
+)
+
+type WindowsExecutorConfig struct {
+ Username string
+ Domain string
+ WorkingDir string
+ Shell string
+ Command string
+ Args []string
+ Interactive bool
+ Pty bool
+ PtyWidth int
+ PtyHeight int
+}
+
+type PrivilegeDropper struct{}
+
+func NewPrivilegeDropper() *PrivilegeDropper {
+ return &PrivilegeDropper{}
+}
+
+var (
+ advapi32 = windows.NewLazyDLL("advapi32.dll")
+ procAllocateLocallyUniqueId = advapi32.NewProc("AllocateLocallyUniqueId")
+)
+
+const (
+ logon32LogonNetwork = 3 // Network logon - no password required for authenticated users
+
+ // Common error messages
+ commandFlag = "-Command"
+ closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
+ convertUsernameError = "convert username to UTF16: %w"
+ convertDomainError = "convert domain to UTF16: %w"
+)
+
+// CreateWindowsExecutorCommand creates a Windows command with privilege dropping.
+// The caller must close the returned token handle after starting the process.
+func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, windows.Token, error) {
+ if config.Username == "" {
+ return nil, 0, errors.New("username cannot be empty")
+ }
+ if config.Shell == "" {
+ return nil, 0, errors.New("shell cannot be empty")
+ }
+
+ shell := config.Shell
+
+ var shellArgs []string
+ if config.Command != "" {
+ shellArgs = []string{shell, commandFlag, config.Command}
+ } else {
+ shellArgs = []string{shell}
+ }
+
+ log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
+
+ cmd, token, err := pd.CreateWindowsProcessAsUser(
+ ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
+ if err != nil {
+ return nil, 0, fmt.Errorf("create Windows process as user: %w", err)
+ }
+
+ return cmd, token, nil
+}
+
+const (
+ // StatusSuccess represents successful LSA operation
+ StatusSuccess = 0
+
+ // KerbS4ULogonType message type for domain users with Kerberos
+ KerbS4ULogonType = 12
+ // Msv10s4ulogontype message type for local users with MSV1_0
+ Msv10s4ulogontype = 12
+
+ // MicrosoftKerberosNameA is the authentication package name for Kerberos
+ MicrosoftKerberosNameA = "Kerberos"
+ // Msv10packagename is the authentication package name for MSV1_0
+ Msv10packagename = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0"
+
+ NameSamCompatible = 2
+ NameUserPrincipal = 8
+ NameCanonical = 7
+
+ maxUPNLen = 1024
+)
+
+// kerbS4ULogon structure for S4U authentication (domain users)
+type kerbS4ULogon struct {
+ MessageType uint32
+ Flags uint32
+ ClientUpn unicodeString
+ ClientRealm unicodeString
+}
+
+// msv10s4ulogon structure for S4U authentication (local users)
+type msv10s4ulogon struct {
+ MessageType uint32
+ Flags uint32
+ UserPrincipalName unicodeString
+ DomainName unicodeString
+}
+
+// unicodeString structure
+type unicodeString struct {
+ Length uint16
+ MaximumLength uint16
+ Buffer *uint16
+}
+
+// lsaString structure
+type lsaString struct {
+ Length uint16
+ MaximumLength uint16
+ Buffer *byte
+}
+
+// tokenSource structure
+type tokenSource struct {
+ SourceName [8]byte
+ SourceIdentifier windows.LUID
+}
+
+// quotaLimits structure
+type quotaLimits struct {
+ PagedPoolLimit uint32
+ NonPagedPoolLimit uint32
+ MinimumWorkingSetSize uint32
+ MaximumWorkingSetSize uint32
+ PagefileLimit uint32
+ TimeLimit int64
+}
+
+var (
+ secur32 = windows.NewLazyDLL("secur32.dll")
+ procLsaRegisterLogonProcess = secur32.NewProc("LsaRegisterLogonProcess")
+ procLsaLookupAuthenticationPackage = secur32.NewProc("LsaLookupAuthenticationPackage")
+ procLsaLogonUser = secur32.NewProc("LsaLogonUser")
+ procLsaFreeReturnBuffer = secur32.NewProc("LsaFreeReturnBuffer")
+ procLsaDeregisterLogonProcess = secur32.NewProc("LsaDeregisterLogonProcess")
+ procTranslateNameW = secur32.NewProc("TranslateNameW")
+)
+
+// newLsaString creates an LsaString from a Go string
+func newLsaString(s string) lsaString {
+ b := append([]byte(s), 0)
+ return lsaString{
+ Length: uint16(len(s)),
+ MaximumLength: uint16(len(b)),
+ Buffer: &b[0],
+ }
+}
+
+// generateS4UUserToken creates a Windows token using S4U authentication
+// This is the exact approach OpenSSH for Windows uses for public key authentication
+func generateS4UUserToken(username, domain string) (windows.Handle, error) {
+ userCpn := buildUserCpn(username, domain)
+
+ pd := NewPrivilegeDropper()
+ isDomainUser := !pd.isLocalUser(domain)
+
+ lsaHandle, err := initializeLsaConnection()
+ if err != nil {
+ return 0, err
+ }
+ defer cleanupLsaConnection(lsaHandle)
+
+ authPackageId, err := lookupAuthenticationPackage(lsaHandle, isDomainUser)
+ if err != nil {
+ return 0, err
+ }
+
+ logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
+ if err != nil {
+ return 0, err
+ }
+
+ return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
+}
+
+// buildUserCpn constructs the user principal name
+func buildUserCpn(username, domain string) string {
+ if domain != "" && domain != "." {
+ return fmt.Sprintf(`%s\%s`, domain, username)
+ }
+ return username
+}
+
+// initializeLsaConnection establishes connection to LSA
+func initializeLsaConnection() (windows.Handle, error) {
+
+ processName := newLsaString("NetBird")
+ var mode uint32
+ var lsaHandle windows.Handle
+ ret, _, _ := procLsaRegisterLogonProcess.Call(
+ uintptr(unsafe.Pointer(&processName)),
+ uintptr(unsafe.Pointer(&lsaHandle)),
+ uintptr(unsafe.Pointer(&mode)),
+ )
+ if ret != StatusSuccess {
+ return 0, fmt.Errorf("LsaRegisterLogonProcess: 0x%x", ret)
+ }
+
+ return lsaHandle, nil
+}
+
+// cleanupLsaConnection closes the LSA connection
+func cleanupLsaConnection(lsaHandle windows.Handle) {
+ if ret, _, _ := procLsaDeregisterLogonProcess.Call(uintptr(lsaHandle)); ret != StatusSuccess {
+ log.Debugf("LsaDeregisterLogonProcess failed: 0x%x", ret)
+ }
+}
+
+// lookupAuthenticationPackage finds the correct authentication package
+func lookupAuthenticationPackage(lsaHandle windows.Handle, isDomainUser bool) (uint32, error) {
+ var authPackageName lsaString
+ if isDomainUser {
+ authPackageName = newLsaString(MicrosoftKerberosNameA)
+ } else {
+ authPackageName = newLsaString(Msv10packagename)
+ }
+
+ var authPackageId uint32
+ ret, _, _ := procLsaLookupAuthenticationPackage.Call(
+ uintptr(lsaHandle),
+ uintptr(unsafe.Pointer(&authPackageName)),
+ uintptr(unsafe.Pointer(&authPackageId)),
+ )
+ if ret != StatusSuccess {
+ return 0, fmt.Errorf("LsaLookupAuthenticationPackage: 0x%x", ret)
+ }
+
+ return authPackageId, nil
+}
+
+// lookupPrincipalName converts DOMAIN\username to username@domain.fqdn (UPN format)
+func lookupPrincipalName(username, domain string) (string, error) {
+ samAccountName := fmt.Sprintf(`%s\%s`, domain, username)
+ samAccountNameUtf16, err := windows.UTF16PtrFromString(samAccountName)
+ if err != nil {
+ return "", fmt.Errorf("convert SAM account name to UTF-16: %w", err)
+ }
+
+ upnBuf := make([]uint16, maxUPNLen+1)
+ upnSize := uint32(len(upnBuf))
+
+ ret, _, _ := procTranslateNameW.Call(
+ uintptr(unsafe.Pointer(samAccountNameUtf16)),
+ uintptr(NameSamCompatible),
+ uintptr(NameUserPrincipal),
+ uintptr(unsafe.Pointer(&upnBuf[0])),
+ uintptr(unsafe.Pointer(&upnSize)),
+ )
+
+ if ret != 0 {
+ upn := windows.UTF16ToString(upnBuf[:upnSize])
+ log.Debugf("Translated %s to explicit UPN: %s", samAccountName, upn)
+ return upn, nil
+ }
+
+ upnSize = uint32(len(upnBuf))
+ ret, _, _ = procTranslateNameW.Call(
+ uintptr(unsafe.Pointer(samAccountNameUtf16)),
+ uintptr(NameSamCompatible),
+ uintptr(NameCanonical),
+ uintptr(unsafe.Pointer(&upnBuf[0])),
+ uintptr(unsafe.Pointer(&upnSize)),
+ )
+
+ if ret != 0 {
+ canonical := windows.UTF16ToString(upnBuf[:upnSize])
+ slashIdx := strings.IndexByte(canonical, '/')
+ if slashIdx > 0 {
+ fqdn := canonical[:slashIdx]
+ upn := fmt.Sprintf("%s@%s", username, fqdn)
+ log.Debugf("Translated %s to implicit UPN: %s (from canonical: %s)", samAccountName, upn, canonical)
+ return upn, nil
+ }
+ }
+
+ log.Debugf("Could not translate %s to UPN, using SAM format", samAccountName)
+ return samAccountName, nil
+}
+
+// prepareS4ULogonStructure creates the appropriate S4U logon structure
+func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
+ if isDomainUser {
+ return prepareDomainS4ULogon(username, domain)
+ }
+ return prepareLocalS4ULogon(username)
+}
+
+// prepareDomainS4ULogon creates S4U logon structure for domain users
+func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
+ upn, err := lookupPrincipalName(username, domain)
+ if err != nil {
+ return nil, 0, fmt.Errorf("lookup principal name: %w", err)
+ }
+
+ log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
+
+ upnUtf16, err := windows.UTF16FromString(upn)
+ if err != nil {
+ return nil, 0, fmt.Errorf(convertUsernameError, err)
+ }
+
+ structSize := unsafe.Sizeof(kerbS4ULogon{})
+ upnByteSize := len(upnUtf16) * 2
+ logonInfoSize := structSize + uintptr(upnByteSize)
+
+ buffer := make([]byte, logonInfoSize)
+ logonInfo := unsafe.Pointer(&buffer[0])
+
+ s4uLogon := (*kerbS4ULogon)(logonInfo)
+ s4uLogon.MessageType = KerbS4ULogonType
+ s4uLogon.Flags = 0
+
+ upnOffset := structSize
+ upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset))
+ copy((*[1025]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16)
+
+ s4uLogon.ClientUpn = unicodeString{
+ Length: uint16((len(upnUtf16) - 1) * 2),
+ MaximumLength: uint16(len(upnUtf16) * 2),
+ Buffer: upnBuffer,
+ }
+ s4uLogon.ClientRealm = unicodeString{}
+
+ return logonInfo, logonInfoSize, nil
+}
+
+// prepareLocalS4ULogon creates S4U logon structure for local users
+func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
+ log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
+
+ usernameUtf16, err := windows.UTF16FromString(username)
+ if err != nil {
+ return nil, 0, fmt.Errorf(convertUsernameError, err)
+ }
+
+ domainUtf16, err := windows.UTF16FromString(".")
+ if err != nil {
+ return nil, 0, fmt.Errorf(convertDomainError, err)
+ }
+
+ structSize := unsafe.Sizeof(msv10s4ulogon{})
+ usernameByteSize := len(usernameUtf16) * 2
+ domainByteSize := len(domainUtf16) * 2
+ logonInfoSize := structSize + uintptr(usernameByteSize) + uintptr(domainByteSize)
+
+ buffer := make([]byte, logonInfoSize)
+ logonInfo := unsafe.Pointer(&buffer[0])
+
+ s4uLogon := (*msv10s4ulogon)(logonInfo)
+ s4uLogon.MessageType = Msv10s4ulogontype
+ s4uLogon.Flags = 0x0
+
+ usernameOffset := structSize
+ usernameBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + usernameOffset))
+ copy((*[256]uint16)(unsafe.Pointer(usernameBuffer))[:len(usernameUtf16)], usernameUtf16)
+
+ s4uLogon.UserPrincipalName = unicodeString{
+ Length: uint16((len(usernameUtf16) - 1) * 2),
+ MaximumLength: uint16(len(usernameUtf16) * 2),
+ Buffer: usernameBuffer,
+ }
+
+ domainOffset := usernameOffset + uintptr(usernameByteSize)
+ domainBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + domainOffset))
+ copy((*[16]uint16)(unsafe.Pointer(domainBuffer))[:len(domainUtf16)], domainUtf16)
+
+ s4uLogon.DomainName = unicodeString{
+ Length: uint16((len(domainUtf16) - 1) * 2),
+ MaximumLength: uint16(len(domainUtf16) * 2),
+ Buffer: domainBuffer,
+ }
+
+ return logonInfo, logonInfoSize, nil
+}
+
+// performS4ULogon executes the S4U logon operation
+func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
+ var tokenSource tokenSource
+ copy(tokenSource.SourceName[:], "netbird")
+ if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
+ log.Debugf("AllocateLocallyUniqueId failed")
+ }
+
+ originName := newLsaString("netbird")
+
+ var profile uintptr
+ var profileSize uint32
+ var logonId windows.LUID
+ var token windows.Handle
+ var quotas quotaLimits
+ var subStatus int32
+
+ ret, _, _ := procLsaLogonUser.Call(
+ uintptr(lsaHandle),
+ uintptr(unsafe.Pointer(&originName)),
+ logon32LogonNetwork,
+ uintptr(authPackageId),
+ uintptr(logonInfo),
+ logonInfoSize,
+ 0,
+ uintptr(unsafe.Pointer(&tokenSource)),
+ uintptr(unsafe.Pointer(&profile)),
+ uintptr(unsafe.Pointer(&profileSize)),
+ uintptr(unsafe.Pointer(&logonId)),
+ uintptr(unsafe.Pointer(&token)),
+ uintptr(unsafe.Pointer("as)),
+ uintptr(unsafe.Pointer(&subStatus)),
+ )
+
+ if profile != 0 {
+ if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
+ log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
+ }
+ }
+
+ if ret != StatusSuccess {
+ return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
+ }
+
+ log.Debugf("created S4U %s token for user %s",
+ map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
+ return token, nil
+}
+
+// createToken implements NetBird trust-based authentication using S4U
+func (pd *PrivilegeDropper) createToken(username, domain string) (windows.Handle, error) {
+ fullUsername := buildUserCpn(username, domain)
+
+ if err := userExists(fullUsername, username, domain); err != nil {
+ return 0, err
+ }
+
+ isLocalUser := pd.isLocalUser(domain)
+
+ if isLocalUser {
+ return pd.authenticateLocalUser(username, fullUsername)
+ }
+ return pd.authenticateDomainUser(username, domain, fullUsername)
+}
+
+// userExists checks if the target useVerifier exists on the system
+func userExists(fullUsername, username, domain string) error {
+ if _, err := lookupUser(fullUsername); err != nil {
+ log.Debugf("User %s not found: %v", fullUsername, err)
+ if domain != "" && domain != "." {
+ _, err = lookupUser(username)
+ }
+ if err != nil {
+ return fmt.Errorf("target user %s not found: %w", fullUsername, err)
+ }
+ }
+ return nil
+}
+
+// isLocalUser determines if this is a local user vs domain user
+func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
+ hostname, err := os.Hostname()
+ if err != nil {
+ hostname = "localhost"
+ }
+
+ return domain == "" || domain == "." ||
+ strings.EqualFold(domain, hostname)
+}
+
+// authenticateLocalUser handles authentication for local users
+func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
+ log.Debugf("using S4U authentication for local user %s", fullUsername)
+ token, err := generateS4UUserToken(username, ".")
+ if err != nil {
+ return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
+ }
+ return token, nil
+}
+
+// authenticateDomainUser handles authentication for domain users
+func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
+ log.Debugf("using S4U authentication for domain user %s", fullUsername)
+ token, err := generateS4UUserToken(username, domain)
+ if err != nil {
+ return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
+ }
+ log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
+ return token, nil
+}
+
+// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables).
+// The caller must close the returned token handle after starting the process.
+func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, windows.Token, error) {
+ token, err := pd.createToken(username, domain)
+ if err != nil {
+ return nil, 0, fmt.Errorf("user authentication: %w", err)
+ }
+
+ defer func() {
+ if err := windows.CloseHandle(token); err != nil {
+ log.Debugf("close impersonation token: %v", err)
+ }
+ }()
+
+ cmd, primaryToken, err := pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ return cmd, primaryToken, nil
+}
+
+// createProcessWithToken creates process with the specified token and executable path.
+// The caller must close the returned token handle after starting the process.
+func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, windows.Token, error) {
+ cmd := exec.CommandContext(ctx, executablePath, args[1:]...)
+ cmd.Dir = workingDir
+
+ var primaryToken windows.Token
+ err := windows.DuplicateTokenEx(
+ sourceToken,
+ windows.TOKEN_ALL_ACCESS,
+ nil,
+ windows.SecurityIdentification,
+ windows.TokenPrimary,
+ &primaryToken,
+ )
+ if err != nil {
+ return nil, 0, fmt.Errorf("duplicate token to primary token: %w", err)
+ }
+
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Token: syscall.Token(primaryToken),
+ }
+
+ return cmd, primaryToken, nil
+}
+
+// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
+func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
+ return nil, fmt.Errorf("su command not available on Windows")
+}
diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go
new file mode 100644
index 000000000..d36d7cbbf
--- /dev/null
+++ b/client/ssh/server/jwt_test.go
@@ -0,0 +1,647 @@
+package server
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "encoding/base64"
+ "encoding/json"
+ "io"
+ "math/big"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "runtime"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ cryptossh "golang.org/x/crypto/ssh"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ nbssh "github.com/netbirdio/netbird/client/ssh"
+ sshauth "github.com/netbirdio/netbird/client/ssh/auth"
+ "github.com/netbirdio/netbird/client/ssh/client"
+ "github.com/netbirdio/netbird/client/ssh/detection"
+ "github.com/netbirdio/netbird/client/ssh/testutil"
+ nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
+ sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
+)
+
+func TestJWTEnforcement(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping JWT enforcement tests in short mode")
+ }
+
+ // Set up SSH server
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ t.Run("blocks_without_jwt", func(t *testing.T) {
+ jwtConfig := &JWTConfig{
+ Issuer: "test-issuer",
+ Audience: "test-audience",
+ KeysLocation: "test-keys",
+ }
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: jwtConfig,
+ }
+ server := New(serverConfig)
+ server.SetAllowRootLogin(true)
+
+ serverAddr := StartTestServer(t, server)
+ defer require.NoError(t, server.Stop())
+
+ host, portStr, err := net.SplitHostPort(serverAddr)
+ require.NoError(t, err)
+ port, err := strconv.Atoi(portStr)
+ require.NoError(t, err)
+ dialer := &net.Dialer{}
+ serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
+ if err != nil {
+ t.Logf("Detection failed: %v", err)
+ }
+ t.Logf("Detected server type: %s", serverType)
+
+ config := &cryptossh.ClientConfig{
+ User: testutil.GetTestUsername(t),
+ Auth: []cryptossh.AuthMethod{},
+ HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
+ Timeout: 2 * time.Second,
+ }
+
+ _, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
+ assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
+ })
+
+ t.Run("allows_when_disabled", func(t *testing.T) {
+ serverConfigNoJWT := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ serverNoJWT := New(serverConfigNoJWT)
+ require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
+ serverNoJWT.SetAllowRootLogin(true)
+
+ serverAddrNoJWT := StartTestServer(t, serverNoJWT)
+ defer require.NoError(t, serverNoJWT.Stop())
+
+ hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
+ require.NoError(t, err)
+ portNoJWT, err := strconv.Atoi(portStrNoJWT)
+ require.NoError(t, err)
+
+ dialer := &net.Dialer{}
+ serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
+ require.NoError(t, err)
+ assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
+ assert.False(t, serverType.RequiresJWT())
+
+ client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
+ require.NoError(t, err)
+ defer client.Close()
+ })
+
+}
+
+// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
+func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
+ privateKey, jwksJSON := generateTestJWKS(t)
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ if _, err := w.Write(jwksJSON); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+ }))
+
+ return server, privateKey, server.URL
+}
+
+// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
+func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ publicKey := &privateKey.PublicKey
+ n := publicKey.N.Bytes()
+ e := publicKey.E
+
+ jwk := nbjwt.JSONWebKey{
+ Kty: "RSA",
+ Kid: "test-key-id",
+ Use: "sig",
+ N: base64RawURLEncode(n),
+ E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
+ }
+
+ jwks := nbjwt.Jwks{
+ Keys: []nbjwt.JSONWebKey{jwk},
+ }
+
+ jwksJSON, err := json.Marshal(jwks)
+ require.NoError(t, err)
+
+ return privateKey, jwksJSON
+}
+
+func base64RawURLEncode(data []byte) string {
+ return base64.RawURLEncoding.EncodeToString(data)
+}
+
+// generateValidJWT creates a valid JWT token for testing
+func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
+ claims := jwt.MapClaims{
+ "iss": issuer,
+ "aud": audience,
+ "sub": "test-user",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ token.Header["kid"] = "test-key-id"
+
+ tokenString, err := token.SignedString(privateKey)
+ require.NoError(t, err)
+
+ return tokenString
+}
+
+// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
+func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
+ t.Helper()
+ addr := net.JoinHostPort(host, strconv.Itoa(port))
+
+ ctx := context.Background()
+ return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
+ InsecureSkipVerify: true,
+ })
+}
+
+// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
+func TestJWTDetection(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping JWT detection test in short mode")
+ }
+
+ jwksServer, _, jwksURL := setupJWKSServer(t)
+ defer jwksServer.Close()
+
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ const (
+ issuer = "https://test-issuer.example.com"
+ audience = "test-audience"
+ )
+
+ jwtConfig := &JWTConfig{
+ Issuer: issuer,
+ Audience: audience,
+ KeysLocation: jwksURL,
+ }
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: jwtConfig,
+ }
+ server := New(serverConfig)
+ server.SetAllowRootLogin(true)
+
+ serverAddr := StartTestServer(t, server)
+ defer require.NoError(t, server.Stop())
+
+ host, portStr, err := net.SplitHostPort(serverAddr)
+ require.NoError(t, err)
+ port, err := strconv.Atoi(portStr)
+ require.NoError(t, err)
+
+ dialer := &net.Dialer{}
+ serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
+ require.NoError(t, err)
+ assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
+ assert.True(t, serverType.RequiresJWT())
+}
+
+func TestJWTFailClose(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping JWT fail-close tests in short mode")
+ }
+
+ jwksServer, privateKey, jwksURL := setupJWKSServer(t)
+ defer jwksServer.Close()
+
+ const (
+ issuer = "https://test-issuer.example.com"
+ audience = "test-audience"
+ )
+
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ testCases := []struct {
+ name string
+ tokenClaims jwt.MapClaims
+ }{
+ {
+ name: "blocks_token_missing_iat",
+ tokenClaims: jwt.MapClaims{
+ "iss": issuer,
+ "aud": audience,
+ "sub": "test-user",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ },
+ },
+ {
+ name: "blocks_token_missing_sub",
+ tokenClaims: jwt.MapClaims{
+ "iss": issuer,
+ "aud": audience,
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ },
+ },
+ {
+ name: "blocks_token_missing_iss",
+ tokenClaims: jwt.MapClaims{
+ "aud": audience,
+ "sub": "test-user",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ },
+ },
+ {
+ name: "blocks_token_missing_aud",
+ tokenClaims: jwt.MapClaims{
+ "iss": issuer,
+ "sub": "test-user",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ },
+ },
+ {
+ name: "blocks_token_wrong_issuer",
+ tokenClaims: jwt.MapClaims{
+ "iss": "wrong-issuer",
+ "aud": audience,
+ "sub": "test-user",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ },
+ },
+ {
+ name: "blocks_token_wrong_audience",
+ tokenClaims: jwt.MapClaims{
+ "iss": issuer,
+ "aud": "wrong-audience",
+ "sub": "test-user",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ },
+ },
+ {
+ name: "blocks_expired_token",
+ tokenClaims: jwt.MapClaims{
+ "iss": issuer,
+ "aud": audience,
+ "sub": "test-user",
+ "exp": time.Now().Add(-time.Hour).Unix(),
+ "iat": time.Now().Add(-2 * time.Hour).Unix(),
+ },
+ },
+ {
+ name: "blocks_token_exceeding_max_age",
+ tokenClaims: jwt.MapClaims{
+ "iss": issuer,
+ "aud": audience,
+ "sub": "test-user",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Add(-2 * time.Hour).Unix(),
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ jwtConfig := &JWTConfig{
+ Issuer: issuer,
+ Audience: audience,
+ KeysLocation: jwksURL,
+ MaxTokenAge: 3600,
+ }
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: jwtConfig,
+ }
+ server := New(serverConfig)
+ server.SetAllowRootLogin(true)
+
+ serverAddr := StartTestServer(t, server)
+ defer require.NoError(t, server.Stop())
+
+ host, portStr, err := net.SplitHostPort(serverAddr)
+ require.NoError(t, err)
+
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
+ token.Header["kid"] = "test-key-id"
+ tokenString, err := token.SignedString(privateKey)
+ require.NoError(t, err)
+
+ config := &cryptossh.ClientConfig{
+ User: testutil.GetTestUsername(t),
+ Auth: []cryptossh.AuthMethod{
+ cryptossh.Password(tokenString),
+ },
+ HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
+ Timeout: 2 * time.Second,
+ }
+
+ conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
+ if conn != nil {
+ defer func() {
+ if err := conn.Close(); err != nil {
+ t.Logf("close connection: %v", err)
+ }
+ }()
+ }
+
+ assert.Error(t, err, "Authentication should fail (fail-close)")
+ })
+ }
+}
+
+// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
+func TestJWTAuthentication(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping JWT authentication tests in short mode")
+ }
+
+ jwksServer, privateKey, jwksURL := setupJWKSServer(t)
+ defer jwksServer.Close()
+
+ const (
+ issuer = "https://test-issuer.example.com"
+ audience = "test-audience"
+ )
+
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ testCases := []struct {
+ name string
+ token string
+ wantAuthOK bool
+ setupServer func(*Server)
+ testOperation func(*testing.T, *cryptossh.Client, string) error
+ wantOpSuccess bool
+ }{
+ {
+ name: "allows_shell_with_jwt",
+ token: "valid",
+ wantAuthOK: true,
+ setupServer: func(s *Server) {
+ s.SetAllowRootLogin(true)
+ },
+ testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
+ session, err := conn.NewSession()
+ require.NoError(t, err)
+ defer session.Close()
+ return session.Shell()
+ },
+ wantOpSuccess: true,
+ },
+ {
+ name: "rejects_invalid_token",
+ token: "invalid",
+ wantAuthOK: false,
+ setupServer: func(s *Server) {
+ s.SetAllowRootLogin(true)
+ },
+ testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
+ session, err := conn.NewSession()
+ require.NoError(t, err)
+ defer session.Close()
+
+ output, err := session.CombinedOutput("echo test")
+ if err != nil {
+ t.Logf("Command output: %s", string(output))
+ return err
+ }
+ return nil
+ },
+ wantOpSuccess: false,
+ },
+ {
+ name: "blocks_shell_without_jwt",
+ token: "",
+ wantAuthOK: false,
+ setupServer: func(s *Server) {
+ s.SetAllowRootLogin(true)
+ },
+ testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
+ session, err := conn.NewSession()
+ require.NoError(t, err)
+ defer session.Close()
+
+ output, err := session.CombinedOutput("echo test")
+ if err != nil {
+ t.Logf("Command output: %s", string(output))
+ return err
+ }
+ return nil
+ },
+ wantOpSuccess: false,
+ },
+ {
+ name: "blocks_command_without_jwt",
+ token: "",
+ wantAuthOK: false,
+ setupServer: func(s *Server) {
+ s.SetAllowRootLogin(true)
+ },
+ testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
+ session, err := conn.NewSession()
+ require.NoError(t, err)
+ defer session.Close()
+
+ output, err := session.CombinedOutput("ls")
+ if err != nil {
+ t.Logf("Command output: %s", string(output))
+ return err
+ }
+ return nil
+ },
+ wantOpSuccess: false,
+ },
+ {
+ name: "allows_sftp_with_jwt",
+ token: "valid",
+ wantAuthOK: true,
+ setupServer: func(s *Server) {
+ s.SetAllowRootLogin(true)
+ s.SetAllowSFTP(true)
+ },
+ testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
+ session, err := conn.NewSession()
+ require.NoError(t, err)
+ defer session.Close()
+
+ session.Stdout = io.Discard
+ session.Stderr = io.Discard
+ return session.RequestSubsystem("sftp")
+ },
+ wantOpSuccess: true,
+ },
+ {
+ name: "blocks_sftp_without_jwt",
+ token: "",
+ wantAuthOK: false,
+ setupServer: func(s *Server) {
+ s.SetAllowRootLogin(true)
+ s.SetAllowSFTP(true)
+ },
+ testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
+ session, err := conn.NewSession()
+ require.NoError(t, err)
+ defer session.Close()
+
+ session.Stdout = io.Discard
+ session.Stderr = io.Discard
+ err = session.RequestSubsystem("sftp")
+ if err == nil {
+ err = session.Wait()
+ }
+ return err
+ },
+ wantOpSuccess: false,
+ },
+ {
+ name: "allows_port_forward_with_jwt",
+ token: "valid",
+ wantAuthOK: true,
+ setupServer: func(s *Server) {
+ s.SetAllowRootLogin(true)
+ s.SetAllowRemotePortForwarding(true)
+ },
+ testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
+ ln, err := conn.Listen("tcp", "127.0.0.1:0")
+ if ln != nil {
+ defer ln.Close()
+ }
+ return err
+ },
+ wantOpSuccess: true,
+ },
+ {
+ name: "blocks_port_forward_without_jwt",
+ token: "",
+ wantAuthOK: false,
+ setupServer: func(s *Server) {
+ s.SetAllowRootLogin(true)
+ s.SetAllowLocalPortForwarding(true)
+ },
+ testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
+ ln, err := conn.Listen("tcp", "127.0.0.1:0")
+ if ln != nil {
+ defer ln.Close()
+ }
+ return err
+ },
+ wantOpSuccess: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // TODO: Skip port forwarding tests on Windows - user switching not supported
+ // These features are tested on Linux/Unix platforms
+ if runtime.GOOS == "windows" &&
+ (tc.name == "allows_port_forward_with_jwt" ||
+ tc.name == "blocks_port_forward_without_jwt") {
+ t.Skip("Skipping port forwarding test on Windows - covered by Linux tests")
+ }
+
+ jwtConfig := &JWTConfig{
+ Issuer: issuer,
+ Audience: audience,
+ KeysLocation: jwksURL,
+ }
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: jwtConfig,
+ }
+ server := New(serverConfig)
+ if tc.setupServer != nil {
+ tc.setupServer(server)
+ }
+
+ // Always set up authorization for test-user to ensure tests fail at JWT validation stage
+ testUserHash, err := sshuserhash.HashUserID("test-user")
+ require.NoError(t, err)
+
+ // Get current OS username for machine user mapping
+ currentUser := testutil.GetTestUsername(t)
+
+ authConfig := &sshauth.Config{
+ UserIDClaim: sshauth.DefaultUserIDClaim,
+ AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
+ MachineUsers: map[string][]uint32{
+ currentUser: {0}, // Allow test-user (index 0) to access current OS user
+ },
+ }
+ server.UpdateSSHAuth(authConfig)
+
+ serverAddr := StartTestServer(t, server)
+ defer require.NoError(t, server.Stop())
+
+ host, portStr, err := net.SplitHostPort(serverAddr)
+ require.NoError(t, err)
+
+ var authMethods []cryptossh.AuthMethod
+ if tc.token == "valid" {
+ token := generateValidJWT(t, privateKey, issuer, audience)
+ authMethods = []cryptossh.AuthMethod{
+ cryptossh.Password(token),
+ }
+ } else if tc.token == "invalid" {
+ invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
+ authMethods = []cryptossh.AuthMethod{
+ cryptossh.Password(invalidToken),
+ }
+ }
+
+ config := &cryptossh.ClientConfig{
+ User: testutil.GetTestUsername(t),
+ Auth: authMethods,
+ HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
+ Timeout: 2 * time.Second,
+ }
+
+ conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
+ if tc.wantAuthOK {
+ require.NoError(t, err, "JWT authentication should succeed")
+ } else if err != nil {
+ t.Logf("Connection failed as expected: %v", err)
+ return
+ }
+ if conn != nil {
+ defer func() {
+ if err := conn.Close(); err != nil {
+ t.Logf("close connection: %v", err)
+ }
+ }()
+ }
+
+ err = tc.testOperation(t, conn, serverAddr)
+ if tc.wantOpSuccess {
+ require.NoError(t, err, "Operation should succeed")
+ } else {
+ assert.Error(t, err, "Operation should fail")
+ }
+ })
+ }
+}
diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go
new file mode 100644
index 000000000..6138f9296
--- /dev/null
+++ b/client/ssh/server/port_forwarding.go
@@ -0,0 +1,386 @@
+package server
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+ cryptossh "golang.org/x/crypto/ssh"
+)
+
+// SessionKey uniquely identifies an SSH session
+type SessionKey string
+
+// ConnectionKey uniquely identifies a port forwarding connection within a session
+type ConnectionKey string
+
+// ForwardKey uniquely identifies a port forwarding listener
+type ForwardKey string
+
+// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
+type tcpipForwardMsg struct {
+ Host string
+ Port uint32
+}
+
+// SetAllowLocalPortForwarding configures local port forwarding
+func (s *Server) SetAllowLocalPortForwarding(allow bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.allowLocalPortForwarding = allow
+}
+
+// SetAllowRemotePortForwarding configures remote port forwarding
+func (s *Server) SetAllowRemotePortForwarding(allow bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.allowRemotePortForwarding = allow
+}
+
+// configurePortForwarding sets up port forwarding callbacks
+func (s *Server) configurePortForwarding(server *ssh.Server) {
+ allowLocal := s.allowLocalPortForwarding
+ allowRemote := s.allowRemotePortForwarding
+
+ server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
+ if !allowLocal {
+ log.Warnf("local port forwarding denied for %s from %s: disabled by configuration",
+ net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr())
+ return false
+ }
+
+ if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
+ log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err)
+ return false
+ }
+
+ log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort)
+ return true
+ }
+
+ server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
+ if !allowRemote {
+ log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration",
+ net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr())
+ return false
+ }
+
+ if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
+ log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err)
+ return false
+ }
+
+ log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort)
+ return true
+ }
+
+ log.Debugf("SSH server configured with local_forwarding=%v, remote_forwarding=%v", allowLocal, allowRemote)
+}
+
+// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
+// Returns nil if allowed, error if denied.
+func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
+ if ctx == nil {
+ return fmt.Errorf("%s port forwarding denied: no context", forwardType)
+ }
+
+ username := ctx.User()
+ remoteAddr := "unknown"
+ if ctx.RemoteAddr() != nil {
+ remoteAddr = ctx.RemoteAddr().String()
+ }
+
+ logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port})
+
+ result := s.CheckPrivileges(PrivilegeCheckRequest{
+ RequestedUsername: username,
+ FeatureSupportsUserSwitch: false,
+ FeatureName: forwardType + " port forwarding",
+ })
+
+ if !result.Allowed {
+ return result.Error
+ }
+
+ logger.Debugf("%s port forwarding allowed: user %s validated (port %d)",
+ forwardType, result.User.Username, port)
+
+ return nil
+}
+
+// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
+func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
+ logger := s.getRequestLogger(ctx)
+
+ if !s.isRemotePortForwardingAllowed() {
+ logger.Warnf("tcpip-forward request denied: remote port forwarding disabled")
+ return false, nil
+ }
+
+ payload, err := s.parseTcpipForwardRequest(req)
+ if err != nil {
+ logger.Errorf("tcpip-forward unmarshal error: %v", err)
+ return false, nil
+ }
+
+ if err := s.checkPortForwardingPrivileges(ctx, "tcpip-forward", payload.Port); err != nil {
+ logger.Warnf("tcpip-forward denied: %v", err)
+ return false, nil
+ }
+
+ logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
+
+ sshConn, err := s.getSSHConnection(ctx)
+ if err != nil {
+ logger.Warnf("tcpip-forward request denied: %v", err)
+ return false, nil
+ }
+
+ return s.setupDirectForward(ctx, logger, sshConn, payload)
+}
+
+// cancelTcpipForwardHandler handles cancel-tcpip-forward requests.
+func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
+ logger := s.getRequestLogger(ctx)
+
+ var payload tcpipForwardMsg
+ if err := cryptossh.Unmarshal(req.Payload, &payload); err != nil {
+ logger.Errorf("cancel-tcpip-forward unmarshal error: %v", err)
+ return false, nil
+ }
+
+ key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
+ if s.removeRemoteForwardListener(key) {
+ logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
+ return true, nil
+ }
+
+ logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
+ return false, nil
+}
+
+// handleRemoteForwardListener handles incoming connections for remote port forwarding.
+func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
+ log.Debugf("starting remote forward listener handler for %s:%d", host, port)
+
+ defer func() {
+ log.Debugf("cleaning up remote forward listener for %s:%d", host, port)
+ if err := ln.Close(); err != nil {
+ log.Debugf("remote forward listener close error: %v", err)
+ } else {
+ log.Debugf("remote forward listener closed successfully for %s:%d", host, port)
+ }
+ }()
+
+ acceptChan := make(chan acceptResult, 1)
+
+ go func() {
+ for {
+ conn, err := ln.Accept()
+ select {
+ case acceptChan <- acceptResult{conn: conn, err: err}:
+ if err != nil {
+ return
+ }
+ case <-ctx.Done():
+ return
+ }
+ }
+ }()
+
+ for {
+ select {
+ case result := <-acceptChan:
+ if result.err != nil {
+ log.Debugf("remote forward accept error: %v", result.err)
+ return
+ }
+ go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
+ case <-ctx.Done():
+ log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port)
+ return
+ }
+ }
+}
+
+// getRequestLogger creates a logger with user and remote address context
+func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
+ remoteAddr := "unknown"
+ username := "unknown"
+ if ctx != nil {
+ if ctx.RemoteAddr() != nil {
+ remoteAddr = ctx.RemoteAddr().String()
+ }
+ username = ctx.User()
+ }
+ return log.WithFields(log.Fields{"user": username, "remote": remoteAddr})
+}
+
+// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
+func (s *Server) isRemotePortForwardingAllowed() bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.allowRemotePortForwarding
+}
+
+// parseTcpipForwardRequest parses the SSH request payload
+func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
+ var payload tcpipForwardMsg
+ err := cryptossh.Unmarshal(req.Payload, &payload)
+ return &payload, err
+}
+
+// getSSHConnection extracts SSH connection from context
+func (s *Server) getSSHConnection(ctx ssh.Context) (*cryptossh.ServerConn, error) {
+ if ctx == nil {
+ return nil, fmt.Errorf("no context")
+ }
+ sshConnValue := ctx.Value(ssh.ContextKeyConn)
+ if sshConnValue == nil {
+ return nil, fmt.Errorf("no SSH connection in context")
+ }
+ sshConn, ok := sshConnValue.(*cryptossh.ServerConn)
+ if !ok || sshConn == nil {
+ return nil, fmt.Errorf("invalid SSH connection in context")
+ }
+ return sshConn, nil
+}
+
+// setupDirectForward sets up a direct port forward
+func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn *cryptossh.ServerConn, payload *tcpipForwardMsg) (bool, []byte) {
+ bindAddr := net.JoinHostPort(payload.Host, strconv.FormatUint(uint64(payload.Port), 10))
+
+ ln, err := net.Listen("tcp", bindAddr)
+ if err != nil {
+ logger.Errorf("tcpip-forward listen failed on %s: %v", bindAddr, err)
+ return false, nil
+ }
+
+ actualPort := payload.Port
+ if payload.Port == 0 {
+ tcpAddr := ln.Addr().(*net.TCPAddr)
+ actualPort = uint32(tcpAddr.Port)
+ logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
+ }
+
+ key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
+ s.storeRemoteForwardListener(key, ln)
+
+ s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String())
+ go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
+
+ response := make([]byte, 4)
+ binary.BigEndian.PutUint32(response, actualPort)
+
+ logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
+ return true, response
+}
+
+// acceptResult holds the result of a listener Accept() call
+type acceptResult struct {
+ conn net.Conn
+ err error
+}
+
+// handleRemoteForwardConnection handles a single remote port forwarding connection
+func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
+ sessionKey := s.findSessionKeyByContext(ctx)
+ connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
+ logger := log.WithFields(log.Fields{
+ "session": sessionKey,
+ "conn": connID,
+ })
+
+ defer func() {
+ if err := conn.Close(); err != nil {
+ logger.Debugf("connection close error: %v", err)
+ }
+ }()
+
+ sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
+ if sshConn == nil {
+ logger.Debugf("remote forward: no SSH connection in context")
+ return
+ }
+
+ remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
+ if !ok {
+ logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
+ return
+ }
+
+ channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
+ if err != nil {
+ logger.Debugf("open forward channel: %v", err)
+ return
+ }
+
+ s.proxyForwardConnection(ctx, logger, conn, channel)
+}
+
+// openForwardChannel creates an SSH forwarded-tcpip channel
+func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) {
+ logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port)
+
+ payload := struct {
+ ConnectedAddress string
+ ConnectedPort uint32
+ OriginatorAddress string
+ OriginatorPort uint32
+ }{
+ ConnectedAddress: host,
+ ConnectedPort: port,
+ OriginatorAddress: remoteAddr.IP.String(),
+ OriginatorPort: uint32(remoteAddr.Port),
+ }
+
+ channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", cryptossh.Marshal(&payload))
+ if err != nil {
+ return nil, fmt.Errorf("open SSH channel: %w", err)
+ }
+
+ go cryptossh.DiscardRequests(reqs)
+ return channel, nil
+}
+
+// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
+func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
+ done := make(chan struct{}, 2)
+
+ go func() {
+ if _, err := io.Copy(channel, conn); err != nil {
+ logger.Debugf("copy error (conn->channel): %v", err)
+ }
+ done <- struct{}{}
+ }()
+
+ go func() {
+ if _, err := io.Copy(conn, channel); err != nil {
+ logger.Debugf("copy error (channel->conn): %v", err)
+ }
+ done <- struct{}{}
+ }()
+
+ select {
+ case <-ctx.Done():
+ logger.Debugf("session ended, closing connections")
+ case <-done:
+ // First copy finished, wait for second copy or context cancellation
+ select {
+ case <-ctx.Done():
+ logger.Debugf("session ended, closing connections")
+ case <-done:
+ }
+ }
+
+ if err := channel.Close(); err != nil {
+ logger.Debugf("channel close error: %v", err)
+ }
+ if err := conn.Close(); err != nil {
+ logger.Debugf("connection close error: %v", err)
+ }
+}
diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go
new file mode 100644
index 000000000..82718d002
--- /dev/null
+++ b/client/ssh/server/server.go
@@ -0,0 +1,751 @@
+package server
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/netip"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gliderlabs/ssh"
+ gojwt "github.com/golang-jwt/jwt/v5"
+ log "github.com/sirupsen/logrus"
+ cryptossh "golang.org/x/crypto/ssh"
+ "golang.org/x/exp/maps"
+ "golang.zx2c4.com/wireguard/tun/netstack"
+
+ "github.com/netbirdio/netbird/client/iface/wgaddr"
+ sshauth "github.com/netbirdio/netbird/client/ssh/auth"
+ "github.com/netbirdio/netbird/client/ssh/detection"
+ "github.com/netbirdio/netbird/shared/auth"
+ "github.com/netbirdio/netbird/shared/auth/jwt"
+ "github.com/netbirdio/netbird/version"
+)
+
+// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
+const DefaultSSHPort = 22
+
+// InternalSSHPort is the port SSH server listens on and is redirected to
+const InternalSSHPort = 22022
+
+const (
+ errWriteSession = "write session error: %v"
+ errExitSession = "exit session error: %v"
+
+ msgPrivilegedUserDisabled = "privileged user login is disabled"
+
+ // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
+ DefaultJWTMaxTokenAge = 5 * 60
+)
+
+var (
+ ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
+ ErrUserNotFound = errors.New("user not found")
+)
+
+// PrivilegedUserError represents an error when privileged user login is disabled
+type PrivilegedUserError struct {
+ Username string
+}
+
+func (e *PrivilegedUserError) Error() string {
+ return fmt.Sprintf("%s for user: %s", msgPrivilegedUserDisabled, e.Username)
+}
+
+func (e *PrivilegedUserError) Is(target error) bool {
+ return target == ErrPrivilegedUserDisabled
+}
+
+// UserNotFoundError represents an error when a user cannot be found
+type UserNotFoundError struct {
+ Username string
+ Cause error
+}
+
+func (e *UserNotFoundError) Error() string {
+ if e.Cause != nil {
+ return fmt.Sprintf("user %s not found: %v", e.Username, e.Cause)
+ }
+ return fmt.Sprintf("user %s not found", e.Username)
+}
+
+func (e *UserNotFoundError) Is(target error) bool {
+ return target == ErrUserNotFound
+}
+
+func (e *UserNotFoundError) Unwrap() error {
+ return e.Cause
+}
+
+// logSessionExitError logs session exit errors, ignoring EOF (normal close) errors
+func logSessionExitError(logger *log.Entry, err error) {
+ if err != nil && !errors.Is(err, io.EOF) {
+ logger.Warnf(errExitSession, err)
+ }
+}
+
+// safeLogCommand returns a safe representation of the command for logging
+func safeLogCommand(cmd []string) string {
+ if len(cmd) == 0 {
+ return "
"
+ }
+ if len(cmd) == 1 {
+ return cmd[0]
+ }
+ return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
+}
+
+type sshConnectionState struct {
+ hasActivePortForward bool
+ username string
+ remoteAddr string
+}
+
+type authKey string
+
+func newAuthKey(username string, remoteAddr net.Addr) authKey {
+ return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
+}
+
+type Server struct {
+ sshServer *ssh.Server
+ mu sync.RWMutex
+ hostKeyPEM []byte
+ sessions map[SessionKey]ssh.Session
+ sessionCancels map[ConnectionKey]context.CancelFunc
+ sessionJWTUsers map[SessionKey]string
+ pendingAuthJWT map[authKey]string
+
+ allowLocalPortForwarding bool
+ allowRemotePortForwarding bool
+ allowRootLogin bool
+ allowSFTP bool
+ jwtEnabled bool
+
+ netstackNet *netstack.Net
+
+ wgAddress wgaddr.Address
+
+ remoteForwardListeners map[ForwardKey]net.Listener
+ sshConnections map[*cryptossh.ServerConn]*sshConnectionState
+
+ jwtValidator *jwt.Validator
+ jwtExtractor *jwt.ClaimsExtractor
+ jwtConfig *JWTConfig
+
+ authorizer *sshauth.Authorizer
+
+ suSupportsPty bool
+ loginIsUtilLinux bool
+}
+
+type JWTConfig struct {
+ Issuer string
+ Audience string
+ KeysLocation string
+ MaxTokenAge int64
+}
+
+// Config contains all SSH server configuration options
+type Config struct {
+ // JWT authentication configuration. If nil, JWT authentication is disabled
+ JWT *JWTConfig
+
+ // HostKey is the SSH server host key in PEM format
+ HostKeyPEM []byte
+}
+
+// SessionInfo contains information about an active SSH session
+type SessionInfo struct {
+ Username string
+ RemoteAddress string
+ Command string
+ JWTUsername string
+}
+
+// New creates an SSH server instance with the provided host key and optional JWT configuration
+// If jwtConfig is nil, JWT authentication is disabled
+func New(config *Config) *Server {
+ s := &Server{
+ mu: sync.RWMutex{},
+ hostKeyPEM: config.HostKeyPEM,
+ sessions: make(map[SessionKey]ssh.Session),
+ sessionJWTUsers: make(map[SessionKey]string),
+ pendingAuthJWT: make(map[authKey]string),
+ remoteForwardListeners: make(map[ForwardKey]net.Listener),
+ sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
+ jwtEnabled: config.JWT != nil,
+ jwtConfig: config.JWT,
+ authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
+ }
+
+ return s
+}
+
+// Start runs the SSH server
+func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.sshServer != nil {
+ return errors.New("SSH server is already running")
+ }
+
+ s.suSupportsPty = s.detectSuPtySupport(ctx)
+ s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx)
+
+ ln, addrDesc, err := s.createListener(ctx, addr)
+ if err != nil {
+ return fmt.Errorf("create listener: %w", err)
+ }
+
+ sshServer, err := s.createSSHServer(ln.Addr())
+ if err != nil {
+ s.closeListener(ln)
+ return fmt.Errorf("create SSH server: %w", err)
+ }
+
+ s.sshServer = sshServer
+ log.Infof("SSH server started on %s", addrDesc)
+
+ go func() {
+ if err := sshServer.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
+ log.Errorf("SSH server error: %v", err)
+ }
+ }()
+ return nil
+}
+
+func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
+ if s.netstackNet != nil {
+ ln, err := s.netstackNet.ListenTCPAddrPort(addr)
+ if err != nil {
+ return nil, "", fmt.Errorf("listen on netstack: %w", err)
+ }
+ return ln, fmt.Sprintf("netstack %s", addr), nil
+ }
+
+ tcpAddr := net.TCPAddrFromAddrPort(addr)
+ lc := net.ListenConfig{}
+ ln, err := lc.Listen(ctx, "tcp", tcpAddr.String())
+ if err != nil {
+ return nil, "", fmt.Errorf("listen: %w", err)
+ }
+ return ln, addr.String(), nil
+}
+
+func (s *Server) closeListener(ln net.Listener) {
+ if ln == nil {
+ return
+ }
+ if err := ln.Close(); err != nil {
+ log.Debugf("listener close error: %v", err)
+ }
+}
+
+// Stop closes the SSH server
+func (s *Server) Stop() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.sshServer == nil {
+ return nil
+ }
+
+ if err := s.sshServer.Close(); err != nil {
+ log.Debugf("close SSH server: %v", err)
+ }
+
+ s.sshServer = nil
+
+ maps.Clear(s.sessions)
+ maps.Clear(s.sessionJWTUsers)
+ maps.Clear(s.pendingAuthJWT)
+ maps.Clear(s.sshConnections)
+
+ for _, cancelFunc := range s.sessionCancels {
+ cancelFunc()
+ }
+ maps.Clear(s.sessionCancels)
+
+ for _, listener := range s.remoteForwardListeners {
+ if err := listener.Close(); err != nil {
+ log.Debugf("close remote forward listener: %v", err)
+ }
+ }
+ maps.Clear(s.remoteForwardListeners)
+
+ return nil
+}
+
+// GetStatus returns the current status of the SSH server and active sessions
+func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ enabled = s.sshServer != nil
+
+ for sessionKey, session := range s.sessions {
+ cmd := ""
+ if len(session.Command()) > 0 {
+ cmd = safeLogCommand(session.Command())
+ }
+
+ jwtUsername := s.sessionJWTUsers[sessionKey]
+
+ sessions = append(sessions, SessionInfo{
+ Username: session.User(),
+ RemoteAddress: session.RemoteAddr().String(),
+ Command: cmd,
+ JWTUsername: jwtUsername,
+ })
+ }
+
+ return enabled, sessions
+}
+
+// SetNetstackNet sets the netstack network for userspace networking
+func (s *Server) SetNetstackNet(net *netstack.Net) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.netstackNet = net
+}
+
+// SetNetworkValidation configures network-based connection filtering
+func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.wgAddress = addr
+}
+
+// UpdateSSHAuth updates the SSH fine-grained access control configuration
+// This should be called when network map updates include new SSH auth configuration
+func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Reset JWT validator/extractor to pick up new userIDClaim
+ s.jwtValidator = nil
+ s.jwtExtractor = nil
+
+ s.authorizer.Update(config)
+}
+
+// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
+func (s *Server) ensureJWTValidator() error {
+ s.mu.RLock()
+ if s.jwtValidator != nil && s.jwtExtractor != nil {
+ s.mu.RUnlock()
+ return nil
+ }
+ config := s.jwtConfig
+ authorizer := s.authorizer
+ s.mu.RUnlock()
+
+ if config == nil {
+ return fmt.Errorf("JWT config not set")
+ }
+
+ log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
+
+ validator := jwt.NewValidator(
+ config.Issuer,
+ []string{config.Audience},
+ config.KeysLocation,
+ true,
+ )
+
+ // Use custom userIDClaim from authorizer if available
+ extractorOptions := []jwt.ClaimsExtractorOption{
+ jwt.WithAudience(config.Audience),
+ }
+ if authorizer.GetUserIDClaim() != "" {
+ extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
+ log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
+ }
+
+ extractor := jwt.NewClaimsExtractor(extractorOptions...)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.jwtValidator != nil && s.jwtExtractor != nil {
+ return nil
+ }
+
+ s.jwtValidator = validator
+ s.jwtExtractor = extractor
+
+ log.Infof("JWT validator initialized successfully")
+ return nil
+}
+
+func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
+ s.mu.RLock()
+ jwtValidator := s.jwtValidator
+ jwtConfig := s.jwtConfig
+ s.mu.RUnlock()
+
+ if jwtValidator == nil {
+ return nil, fmt.Errorf("JWT validator not initialized")
+ }
+
+ token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
+ if err != nil {
+ if jwtConfig != nil {
+ if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
+ return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
+ jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
+ }
+ }
+ return nil, fmt.Errorf("validate token: %w", err)
+ }
+
+ if err := s.checkTokenAge(token, jwtConfig); err != nil {
+ return nil, err
+ }
+
+ return token, nil
+}
+
+func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
+ if jwtConfig == nil {
+ return nil
+ }
+
+ maxTokenAge := jwtConfig.MaxTokenAge
+ if maxTokenAge <= 0 {
+ maxTokenAge = DefaultJWTMaxTokenAge
+ }
+
+ claims, ok := token.Claims.(gojwt.MapClaims)
+ if !ok {
+ userID := extractUserID(token)
+ return fmt.Errorf("token has invalid claims format (user=%s)", userID)
+ }
+
+ iat, ok := claims["iat"].(float64)
+ if !ok {
+ userID := extractUserID(token)
+ return fmt.Errorf("token missing iat claim (user=%s)", userID)
+ }
+
+ issuedAt := time.Unix(int64(iat), 0)
+ tokenAge := time.Since(issuedAt)
+ maxAge := time.Duration(maxTokenAge) * time.Second
+ if tokenAge > maxAge {
+ userID := getUserIDFromClaims(claims)
+ return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
+ }
+
+ return nil
+}
+
+func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
+ s.mu.RLock()
+ jwtExtractor := s.jwtExtractor
+ s.mu.RUnlock()
+
+ if jwtExtractor == nil {
+ userID := extractUserID(token)
+ return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
+ }
+
+ userAuth, err := jwtExtractor.ToUserAuth(token)
+ if err != nil {
+ userID := extractUserID(token)
+ return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
+ }
+
+ if !s.hasSSHAccess(&userAuth) {
+ return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
+ }
+
+ return &userAuth, nil
+}
+
+func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
+ return userAuth.UserId != ""
+}
+
+func extractUserID(token *gojwt.Token) string {
+ if token == nil {
+ return "unknown"
+ }
+ claims, ok := token.Claims.(gojwt.MapClaims)
+ if !ok {
+ return "unknown"
+ }
+ return getUserIDFromClaims(claims)
+}
+
+func getUserIDFromClaims(claims gojwt.MapClaims) string {
+ if sub, ok := claims["sub"].(string); ok && sub != "" {
+ return sub
+ }
+ if userID, ok := claims["user_id"].(string); ok && userID != "" {
+ return userID
+ }
+ if email, ok := claims["email"].(string); ok && email != "" {
+ return email
+ }
+ return "unknown"
+}
+
+func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
+ parts := strings.Split(tokenString, ".")
+ if len(parts) != 3 {
+ return nil, fmt.Errorf("invalid token format")
+ }
+
+ payload, err := base64.RawURLEncoding.DecodeString(parts[1])
+ if err != nil {
+ return nil, fmt.Errorf("decode payload: %w", err)
+ }
+
+ var claims map[string]interface{}
+ if err := json.Unmarshal(payload, &claims); err != nil {
+ return nil, fmt.Errorf("parse claims: %w", err)
+ }
+
+ return claims, nil
+}
+
+func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
+ osUsername := ctx.User()
+ remoteAddr := ctx.RemoteAddr()
+
+ if err := s.ensureJWTValidator(); err != nil {
+ log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
+ return false
+ }
+
+ token, err := s.validateJWTToken(password)
+ if err != nil {
+ log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
+ return false
+ }
+
+ userAuth, err := s.extractAndValidateUser(token)
+ if err != nil {
+ log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
+ return false
+ }
+
+ s.mu.RLock()
+ authorizer := s.authorizer
+ s.mu.RUnlock()
+
+ if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
+ log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
+ return false
+ }
+
+ key := newAuthKey(osUsername, remoteAddr)
+ s.mu.Lock()
+ s.pendingAuthJWT[key] = userAuth.UserId
+ s.mu.Unlock()
+
+ log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
+ return true
+}
+
+func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if state, exists := s.sshConnections[sshConn]; exists {
+ state.hasActivePortForward = true
+ } else {
+ s.sshConnections[sshConn] = &sshConnectionState{
+ hasActivePortForward: true,
+ username: username,
+ remoteAddr: remoteAddr,
+ }
+ }
+}
+
+func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
+ // We can't extract the SSH connection from net.Conn directly
+ // Connection cleanup will happen during session cleanup or via timeout
+ log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
+}
+
+func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
+ if ctx == nil {
+ return "unknown"
+ }
+
+ // Try to match by SSH connection
+ sshConn := ctx.Value(ssh.ContextKeyConn)
+ if sshConn == nil {
+ return "unknown"
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ // Look through sessions to find one with matching connection
+ for sessionKey, session := range s.sessions {
+ if session.Context().Value(ssh.ContextKeyConn) == sshConn {
+ return sessionKey
+ }
+ }
+
+ // If no session found, this might be during early connection setup
+ // Return a temporary key that we'll fix up later
+ if ctx.User() != "" && ctx.RemoteAddr() != nil {
+ tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
+ log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
+ return tempKey
+ }
+
+ return "unknown"
+}
+
+func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
+ s.mu.RLock()
+ netbirdNetwork := s.wgAddress.Network
+ localIP := s.wgAddress.IP
+ s.mu.RUnlock()
+
+ if !netbirdNetwork.IsValid() || !localIP.IsValid() {
+ return conn
+ }
+
+ remoteAddr := conn.RemoteAddr()
+ tcpAddr, ok := remoteAddr.(*net.TCPAddr)
+ if !ok {
+ log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
+ return nil
+ }
+
+ remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
+ if !ok {
+ log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP)
+ return nil
+ }
+
+ // Block connections from our own IP (prevent local apps from connecting to ourselves)
+ if remoteIP == localIP {
+ log.Warnf("SSH connection rejected from own IP %s", remoteIP)
+ return nil
+ }
+
+ if !netbirdNetwork.Contains(remoteIP) {
+ log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
+ return nil
+ }
+
+ log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
+ return conn
+}
+
+func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
+ if err := enableUserSwitching(); err != nil {
+ log.Warnf("failed to enable user switching: %v", err)
+ }
+
+ serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
+ if s.jwtEnabled {
+ serverVersion += " " + detection.JWTRequiredMarker
+ }
+
+ server := &ssh.Server{
+ Addr: addr.String(),
+ Handler: s.sessionHandler,
+ SubsystemHandlers: map[string]ssh.SubsystemHandler{
+ "sftp": s.sftpSubsystemHandler,
+ },
+ HostSigners: []ssh.Signer{},
+ ChannelHandlers: map[string]ssh.ChannelHandler{
+ "session": ssh.DefaultSessionHandler,
+ "direct-tcpip": s.directTCPIPHandler,
+ },
+ RequestHandlers: map[string]ssh.RequestHandler{
+ "tcpip-forward": s.tcpipForwardHandler,
+ "cancel-tcpip-forward": s.cancelTcpipForwardHandler,
+ },
+ ConnCallback: s.connectionValidator,
+ ConnectionFailedCallback: s.connectionCloseHandler,
+ Version: serverVersion,
+ }
+
+ if s.jwtEnabled {
+ server.PasswordHandler = s.passwordHandler
+ }
+
+ hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
+ if err := server.SetOption(hostKeyPEM); err != nil {
+ return nil, fmt.Errorf("set host key: %w", err)
+ }
+
+ s.configurePortForwarding(server)
+ return server, nil
+}
+
+func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.remoteForwardListeners[key] = ln
+}
+
+func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ ln, exists := s.remoteForwardListeners[key]
+ if !exists {
+ return false
+ }
+
+ delete(s.remoteForwardListeners, key)
+ if err := ln.Close(); err != nil {
+ log.Debugf("remote forward listener close error: %v", err)
+ }
+
+ return true
+}
+
+func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
+ var payload struct {
+ Host string
+ Port uint32
+ OriginatorAddr string
+ OriginatorPort uint32
+ }
+
+ if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
+ if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
+ log.Debugf("channel reject error: %v", err)
+ }
+ return
+ }
+
+ s.mu.RLock()
+ allowLocal := s.allowLocalPortForwarding
+ s.mu.RUnlock()
+
+ if !allowLocal {
+ log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port)
+ _ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
+ return
+ }
+
+ // Check privilege requirements for the destination port
+ if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
+ log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
+ _ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
+ return
+ }
+
+ log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
+
+ ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
+}
diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go
new file mode 100644
index 000000000..24e455025
--- /dev/null
+++ b/client/ssh/server/server_config_test.go
@@ -0,0 +1,394 @@
+package server
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "os/user"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/netbirdio/netbird/client/ssh"
+ sshclient "github.com/netbirdio/netbird/client/ssh/client"
+)
+
+func TestServer_RootLoginRestriction(t *testing.T) {
+ // Generate host key for server
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ allowRoot bool
+ username string
+ expectError bool
+ description string
+ }{
+ {
+ name: "root login allowed",
+ allowRoot: true,
+ username: "root",
+ expectError: false,
+ description: "Root login should succeed when allowed",
+ },
+ {
+ name: "root login denied",
+ allowRoot: false,
+ username: "root",
+ expectError: true,
+ description: "Root login should fail when disabled",
+ },
+ {
+ name: "regular user login always allowed",
+ allowRoot: false,
+ username: "testuser",
+ expectError: false,
+ description: "Regular user login should work regardless of root setting",
+ },
+ }
+
+ // Add Windows Administrator tests if on Windows
+ if runtime.GOOS == "windows" {
+ tests = append(tests, []struct {
+ name string
+ allowRoot bool
+ username string
+ expectError bool
+ description string
+ }{
+ {
+ name: "Administrator login allowed",
+ allowRoot: true,
+ username: "Administrator",
+ expectError: false,
+ description: "Administrator login should succeed when allowed",
+ },
+ {
+ name: "Administrator login denied",
+ allowRoot: false,
+ username: "Administrator",
+ expectError: true,
+ description: "Administrator login should fail when disabled",
+ },
+ {
+ name: "administrator login denied (lowercase)",
+ allowRoot: false,
+ username: "administrator",
+ expectError: true,
+ description: "administrator login should fail when disabled (case insensitive)",
+ },
+ }...)
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Mock privileged environment to test root access controls
+ // Set up mock users based on platform
+ mockUsers := map[string]*user.User{
+ "root": createTestUser("root", "0", "0", "/root"),
+ "testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"),
+ }
+
+ // Add Windows-specific users for Administrator tests
+ if runtime.GOOS == "windows" {
+ mockUsers["Administrator"] = createTestUser("Administrator", "500", "544", "C:\\Users\\Administrator")
+ mockUsers["administrator"] = createTestUser("administrator", "500", "544", "C:\\Users\\administrator")
+ }
+
+ cleanup := setupTestDependencies(
+ createTestUser("root", "0", "0", "/root"), // Running as root
+ nil,
+ runtime.GOOS,
+ 0, // euid 0 (root)
+ mockUsers,
+ nil,
+ )
+ defer cleanup()
+
+ // Create server with specific configuration
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ server.SetAllowRootLogin(tt.allowRoot)
+
+ // Test the userNameLookup method directly
+ user, err := server.userNameLookup(tt.username)
+
+ if tt.expectError {
+ assert.Error(t, err, tt.description)
+ if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
+ // Check for appropriate error message based on platform capabilities
+ errorMsg := err.Error()
+ // Either privileged user restriction OR user switching limitation
+ hasPrivilegedError := strings.Contains(errorMsg, "privileged user")
+ hasSwitchingError := strings.Contains(errorMsg, "cannot switch") || strings.Contains(errorMsg, "user switching not supported")
+ assert.True(t, hasPrivilegedError || hasSwitchingError,
+ "Expected privileged user or user switching error, got: %s", errorMsg)
+ }
+ } else {
+ if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
+ // For privileged users, we expect either success or a different error
+ // (like user not found), but not the "login disabled" error
+ if err != nil {
+ assert.NotContains(t, err.Error(), "privileged user login is disabled")
+ }
+ } else {
+ // For regular users, lookup should generally succeed or fall back gracefully
+ // Note: may return current user as fallback
+ assert.NotNil(t, user)
+ }
+ }
+ })
+ }
+}
+
+func TestServer_PortForwardingRestriction(t *testing.T) {
+ // Test that the port forwarding callbacks properly respect configuration flags
+ // This is a unit test of the callback logic, not a full integration test
+
+ // Generate host key for server
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ allowLocalForwarding bool
+ allowRemoteForwarding bool
+ description string
+ }{
+ {
+ name: "all forwarding allowed",
+ allowLocalForwarding: true,
+ allowRemoteForwarding: true,
+ description: "Both local and remote forwarding should be allowed",
+ },
+ {
+ name: "local forwarding disabled",
+ allowLocalForwarding: false,
+ allowRemoteForwarding: true,
+ description: "Local forwarding should be denied when disabled",
+ },
+ {
+ name: "remote forwarding disabled",
+ allowLocalForwarding: true,
+ allowRemoteForwarding: false,
+ description: "Remote forwarding should be denied when disabled",
+ },
+ {
+ name: "all forwarding disabled",
+ allowLocalForwarding: false,
+ allowRemoteForwarding: false,
+ description: "Both forwarding types should be denied when disabled",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create server with specific configuration
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
+ server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
+
+ // We need to access the internal configuration to simulate the callback tests
+ // Since the callbacks are created inside the Start method, we'll test the logic directly
+
+ // Test the configuration values are set correctly
+ server.mu.RLock()
+ allowLocal := server.allowLocalPortForwarding
+ allowRemote := server.allowRemotePortForwarding
+ server.mu.RUnlock()
+
+ assert.Equal(t, tt.allowLocalForwarding, allowLocal, "Local forwarding configuration should be set correctly")
+ assert.Equal(t, tt.allowRemoteForwarding, allowRemote, "Remote forwarding configuration should be set correctly")
+
+ // Simulate the callback logic
+ localResult := allowLocal // This would be the callback return value
+ remoteResult := allowRemote // This would be the callback return value
+
+ assert.Equal(t, tt.allowLocalForwarding, localResult,
+ "Local port forwarding callback should return correct value")
+ assert.Equal(t, tt.allowRemoteForwarding, remoteResult,
+ "Remote port forwarding callback should return correct value")
+ })
+ }
+}
+
+func TestServer_PortConflictHandling(t *testing.T) {
+ // Test that multiple sessions requesting the same local port are handled naturally by the OS
+ // Get current user for SSH connection
+ currentUser, err := user.Current()
+ require.NoError(t, err, "Should be able to get current user")
+
+ // Generate host key for server
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ // Create server
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ server.SetAllowRootLogin(true)
+
+ serverAddr := StartTestServer(t, server)
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ // Get a free port for testing
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ testPort := ln.Addr().(*net.TCPAddr).Port
+ err = ln.Close()
+ require.NoError(t, err)
+
+ // Connect first client
+ ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel1()
+
+ client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
+ InsecureSkipVerify: true,
+ })
+ require.NoError(t, err)
+ defer func() {
+ err := client1.Close()
+ assert.NoError(t, err)
+ }()
+
+ // Connect second client
+ ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel2()
+
+ client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
+ InsecureSkipVerify: true,
+ })
+ require.NoError(t, err)
+ defer func() {
+ err := client2.Close()
+ assert.NoError(t, err)
+ }()
+
+ // First client binds to the test port
+ localAddr1 := fmt.Sprintf("127.0.0.1:%d", testPort)
+ remoteAddr := "127.0.0.1:80"
+
+ // Start first client's port forwarding
+ done1 := make(chan error, 1)
+ go func() {
+ // This should succeed and hold the port
+ err := client1.LocalPortForward(ctx1, localAddr1, remoteAddr)
+ done1 <- err
+ }()
+
+ // Give first client time to bind
+ time.Sleep(200 * time.Millisecond)
+
+ // Second client tries to bind to same port
+ localAddr2 := fmt.Sprintf("127.0.0.1:%d", testPort)
+
+ shortCtx, shortCancel := context.WithTimeout(context.Background(), 1*time.Second)
+ defer shortCancel()
+
+ err = client2.LocalPortForward(shortCtx, localAddr2, remoteAddr)
+ // Second client should fail due to "address already in use"
+ assert.Error(t, err, "Second client should fail to bind to same port")
+ if err != nil {
+ // The error should indicate the address is already in use
+ errMsg := strings.ToLower(err.Error())
+ if runtime.GOOS == "windows" {
+ assert.Contains(t, errMsg, "only one usage of each socket address",
+ "Error should indicate port conflict")
+ } else {
+ assert.Contains(t, errMsg, "address already in use",
+ "Error should indicate port conflict")
+ }
+ }
+
+ // Cancel first client's context and wait for it to finish
+ cancel1()
+ select {
+ case err1 := <-done1:
+ // Should get context cancelled or deadline exceeded
+ assert.Error(t, err1, "First client should exit when context cancelled")
+ case <-time.After(2 * time.Second):
+ t.Error("First client did not exit within timeout")
+ }
+}
+
+func TestServer_IsPrivilegedUser(t *testing.T) {
+
+ tests := []struct {
+ username string
+ expected bool
+ description string
+ }{
+ {
+ username: "root",
+ expected: true,
+ description: "root should be considered privileged",
+ },
+ {
+ username: "regular",
+ expected: false,
+ description: "regular user should not be privileged",
+ },
+ {
+ username: "",
+ expected: false,
+ description: "empty username should not be privileged",
+ },
+ }
+
+ // Add Windows-specific tests
+ if runtime.GOOS == "windows" {
+ tests = append(tests, []struct {
+ username string
+ expected bool
+ description string
+ }{
+ {
+ username: "Administrator",
+ expected: true,
+ description: "Administrator should be considered privileged on Windows",
+ },
+ {
+ username: "administrator",
+ expected: true,
+ description: "administrator should be considered privileged on Windows (case insensitive)",
+ },
+ }...)
+ } else {
+ // On non-Windows systems, Administrator should not be privileged
+ tests = append(tests, []struct {
+ username string
+ expected bool
+ description string
+ }{
+ {
+ username: "Administrator",
+ expected: false,
+ description: "Administrator should not be privileged on non-Windows systems",
+ },
+ }...)
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.description, func(t *testing.T) {
+ result := isPrivilegedUsername(tt.username)
+ assert.Equal(t, tt.expected, result, tt.description)
+ })
+ }
+}
diff --git a/client/ssh/server/server_test.go b/client/ssh/server/server_test.go
new file mode 100644
index 000000000..661068539
--- /dev/null
+++ b/client/ssh/server/server_test.go
@@ -0,0 +1,441 @@
+package server
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/netip"
+ "os/user"
+ "runtime"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ cryptossh "golang.org/x/crypto/ssh"
+
+ nbssh "github.com/netbirdio/netbird/client/ssh"
+)
+
+func TestServer_StartStop(t *testing.T) {
+ key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ serverConfig := &Config{
+ HostKeyPEM: key,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+
+ err = server.Stop()
+ assert.NoError(t, err)
+}
+
+func TestSSHServerIntegration(t *testing.T) {
+ // Generate host key for server
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ // Generate client key pair
+ clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ // Create server with random port
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+
+ // Start server in background
+ serverAddr := "127.0.0.1:0"
+ started := make(chan string, 1)
+ errChan := make(chan error, 1)
+
+ go func() {
+ // Get a free port
+ ln, err := net.Listen("tcp", serverAddr)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ actualAddr := ln.Addr().String()
+ if err := ln.Close(); err != nil {
+ errChan <- fmt.Errorf("close temp listener: %w", err)
+ return
+ }
+
+ addrPort, _ := netip.ParseAddrPort(actualAddr)
+ if err := server.Start(context.Background(), addrPort); err != nil {
+ errChan <- err
+ return
+ }
+ started <- actualAddr
+ }()
+
+ select {
+ case actualAddr := <-started:
+ serverAddr = actualAddr
+ case err := <-errChan:
+ t.Fatalf("Server failed to start: %v", err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Server start timeout")
+ }
+
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ // Parse client private key
+ signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
+ require.NoError(t, err)
+
+ // Parse server host key for verification
+ hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
+ require.NoError(t, err)
+ hostPubKey := hostPrivParsed.PublicKey()
+
+ // Get current user for SSH connection
+ currentUser, err := user.Current()
+ require.NoError(t, err, "Should be able to get current user for test")
+
+ // Create SSH client config
+ config := &cryptossh.ClientConfig{
+ User: currentUser.Username,
+ Auth: []cryptossh.AuthMethod{
+ cryptossh.PublicKeys(signer),
+ },
+ HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
+ Timeout: 3 * time.Second,
+ }
+
+ // Connect to SSH server
+ client, err := cryptossh.Dial("tcp", serverAddr, config)
+ require.NoError(t, err)
+ defer func() {
+ if err := client.Close(); err != nil {
+ t.Logf("close client: %v", err)
+ }
+ }()
+
+ // Test creating a session
+ session, err := client.NewSession()
+ require.NoError(t, err)
+ defer func() {
+ if err := session.Close(); err != nil {
+ t.Logf("close session: %v", err)
+ }
+ }()
+
+ // Note: Since we don't have a real shell environment in tests,
+ // we can't test actual command execution, but we can verify
+ // the connection and authentication work
+ t.Log("SSH connection and authentication successful")
+}
+
+func TestSSHServerMultipleConnections(t *testing.T) {
+ // Generate host key for server
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ // Generate client key pair
+ clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ // Create server
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+
+ // Start server
+ serverAddr := "127.0.0.1:0"
+ started := make(chan string, 1)
+ errChan := make(chan error, 1)
+
+ go func() {
+ ln, err := net.Listen("tcp", serverAddr)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ actualAddr := ln.Addr().String()
+ if err := ln.Close(); err != nil {
+ errChan <- fmt.Errorf("close temp listener: %w", err)
+ return
+ }
+
+ addrPort, _ := netip.ParseAddrPort(actualAddr)
+ if err := server.Start(context.Background(), addrPort); err != nil {
+ errChan <- err
+ return
+ }
+ started <- actualAddr
+ }()
+
+ select {
+ case actualAddr := <-started:
+ serverAddr = actualAddr
+ case err := <-errChan:
+ t.Fatalf("Server failed to start: %v", err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Server start timeout")
+ }
+
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ // Parse client private key
+ signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
+ require.NoError(t, err)
+
+ // Parse server host key
+ hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
+ require.NoError(t, err)
+ hostPubKey := hostPrivParsed.PublicKey()
+
+ // Get current user for SSH connection
+ currentUser, err := user.Current()
+ require.NoError(t, err, "Should be able to get current user for test")
+
+ config := &cryptossh.ClientConfig{
+ User: currentUser.Username,
+ Auth: []cryptossh.AuthMethod{
+ cryptossh.PublicKeys(signer),
+ },
+ HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
+ Timeout: 3 * time.Second,
+ }
+
+ // Test multiple concurrent connections
+ const numConnections = 5
+ results := make(chan error, numConnections)
+
+ for i := 0; i < numConnections; i++ {
+ go func(id int) {
+ client, err := cryptossh.Dial("tcp", serverAddr, config)
+ if err != nil {
+ results <- fmt.Errorf("connection %d failed: %w", id, err)
+ return
+ }
+ defer func() {
+ _ = client.Close() // Ignore error in test goroutine
+ }()
+
+ session, err := client.NewSession()
+ if err != nil {
+ results <- fmt.Errorf("session %d failed: %w", id, err)
+ return
+ }
+ defer func() {
+ _ = session.Close() // Ignore error in test goroutine
+ }()
+
+ results <- nil
+ }(i)
+ }
+
+ // Wait for all connections to complete
+ for i := 0; i < numConnections; i++ {
+ select {
+ case err := <-results:
+ assert.NoError(t, err)
+ case <-time.After(10 * time.Second):
+ t.Fatalf("Connection %d timed out", i)
+ }
+ }
+}
+
+func TestSSHServerNoAuthMode(t *testing.T) {
+ // Generate host key for server
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ // Create server
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+
+ // Start server
+ serverAddr := "127.0.0.1:0"
+ started := make(chan string, 1)
+ errChan := make(chan error, 1)
+
+ go func() {
+ ln, err := net.Listen("tcp", serverAddr)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ actualAddr := ln.Addr().String()
+ if err := ln.Close(); err != nil {
+ errChan <- fmt.Errorf("close temp listener: %w", err)
+ return
+ }
+
+ addrPort, _ := netip.ParseAddrPort(actualAddr)
+ if err := server.Start(context.Background(), addrPort); err != nil {
+ errChan <- err
+ return
+ }
+ started <- actualAddr
+ }()
+
+ select {
+ case actualAddr := <-started:
+ serverAddr = actualAddr
+ case err := <-errChan:
+ t.Fatalf("Server failed to start: %v", err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Server start timeout")
+ }
+
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ // Generate a client private key for SSH protocol (server doesn't check it)
+ clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+ clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
+ require.NoError(t, err)
+
+ // Parse server host key
+ hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
+ require.NoError(t, err)
+ hostPubKey := hostPrivParsed.PublicKey()
+
+ // Get current user for SSH connection
+ currentUser, err := user.Current()
+ require.NoError(t, err, "Should be able to get current user for test")
+
+ // Try to connect with client key
+ config := &cryptossh.ClientConfig{
+ User: currentUser.Username,
+ Auth: []cryptossh.AuthMethod{
+ cryptossh.PublicKeys(clientSigner),
+ },
+ HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
+ Timeout: 3 * time.Second,
+ }
+
+ // This should succeed in no-auth mode (server doesn't verify keys)
+ conn, err := cryptossh.Dial("tcp", serverAddr, config)
+ assert.NoError(t, err, "Connection should succeed in no-auth mode")
+ if conn != nil {
+ assert.NoError(t, conn.Close())
+ }
+}
+
+func TestSSHServerStartStopCycle(t *testing.T) {
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ serverAddr := "127.0.0.1:0"
+
+ // Test multiple start/stop cycles
+ for i := 0; i < 3; i++ {
+ t.Logf("Start/stop cycle %d", i+1)
+
+ started := make(chan string, 1)
+ errChan := make(chan error, 1)
+
+ go func() {
+ ln, err := net.Listen("tcp", serverAddr)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ actualAddr := ln.Addr().String()
+ if err := ln.Close(); err != nil {
+ errChan <- fmt.Errorf("close temp listener: %w", err)
+ return
+ }
+
+ addrPort, _ := netip.ParseAddrPort(actualAddr)
+ if err := server.Start(context.Background(), addrPort); err != nil {
+ errChan <- err
+ return
+ }
+ started <- actualAddr
+ }()
+
+ select {
+ case <-started:
+ case err := <-errChan:
+ t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err)
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Cycle %d: Server start timeout", i+1)
+ }
+
+ err = server.Stop()
+ require.NoError(t, err, "Cycle %d: Stop should succeed", i+1)
+ }
+}
+
+func TestSSHServer_WindowsShellHandling(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping Windows shell test in short mode")
+ }
+
+ server := &Server{}
+
+ if runtime.GOOS == "windows" {
+ // Test Windows cmd.exe shell behavior
+ args := server.getShellCommandArgs("cmd.exe", "echo test")
+ assert.Equal(t, "cmd.exe", args[0])
+ assert.Equal(t, "-Command", args[1])
+ assert.Equal(t, "echo test", args[2])
+
+ // Test PowerShell behavior
+ args = server.getShellCommandArgs("powershell.exe", "echo test")
+ assert.Equal(t, "powershell.exe", args[0])
+ assert.Equal(t, "-Command", args[1])
+ assert.Equal(t, "echo test", args[2])
+ } else {
+ // Test Unix shell behavior
+ args := server.getShellCommandArgs("/bin/sh", "echo test")
+ assert.Equal(t, "/bin/sh", args[0])
+ assert.Equal(t, "-l", args[1])
+ assert.Equal(t, "-c", args[2])
+ assert.Equal(t, "echo test", args[3])
+ }
+}
+
+func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+
+ serverConfig1 := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server1 := New(serverConfig1)
+
+ serverConfig2 := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server2 := New(serverConfig2)
+
+ assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
+ assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
+
+ server2.SetAllowLocalPortForwarding(true)
+ server2.SetAllowRemotePortForwarding(true)
+
+ assert.True(t, server2.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set")
+ assert.True(t, server2.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set")
+}
diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go
new file mode 100644
index 000000000..4e6d72098
--- /dev/null
+++ b/client/ssh/server/session_handlers.go
@@ -0,0 +1,168 @@
+package server
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "strings"
+ "time"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+ cryptossh "golang.org/x/crypto/ssh"
+)
+
+// sessionHandler handles SSH sessions
+func (s *Server) sessionHandler(session ssh.Session) {
+ sessionKey := s.registerSession(session)
+
+ key := newAuthKey(session.User(), session.RemoteAddr())
+ s.mu.Lock()
+ jwtUsername := s.pendingAuthJWT[key]
+ if jwtUsername != "" {
+ s.sessionJWTUsers[sessionKey] = jwtUsername
+ delete(s.pendingAuthJWT, key)
+ }
+ s.mu.Unlock()
+
+ logger := log.WithField("session", sessionKey)
+ if jwtUsername != "" {
+ logger = logger.WithField("jwt_user", jwtUsername)
+ logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
+ } else {
+ logger.Infof("SSH session started")
+ }
+ sessionStart := time.Now()
+
+ defer s.unregisterSession(sessionKey, session)
+ defer func() {
+ duration := time.Since(sessionStart).Round(time.Millisecond)
+ if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
+ logger.Warnf("close session after %v: %v", duration, err)
+ }
+ logger.Infof("SSH session closed after %v", duration)
+ }()
+
+ privilegeResult, err := s.userPrivilegeCheck(session.User())
+ if err != nil {
+ s.handlePrivError(logger, session, err)
+ return
+ }
+
+ ptyReq, winCh, isPty := session.Pty()
+ hasCommand := len(session.Command()) > 0
+
+ switch {
+ case isPty && hasCommand:
+ // ssh -t - Pty command execution
+ s.handleCommand(logger, session, privilegeResult, winCh)
+ case isPty:
+ // ssh - Pty interactive session (login)
+ s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
+ case hasCommand:
+ // ssh - non-Pty command execution
+ s.handleCommand(logger, session, privilegeResult, nil)
+ default:
+ s.rejectInvalidSession(logger, session)
+ }
+}
+
+func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
+ if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
+ logger.Debugf(errWriteSession, err)
+ }
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
+}
+
+func (s *Server) registerSession(session ssh.Session) SessionKey {
+ sessionID := session.Context().Value(ssh.ContextKeySessionID)
+ if sessionID == nil {
+ sessionID = fmt.Sprintf("%p", session)
+ }
+
+ // Create a short 4-byte identifier from the full session ID
+ hasher := sha256.New()
+ hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
+ hash := hasher.Sum(nil)
+ shortID := hex.EncodeToString(hash[:4])
+
+ remoteAddr := session.RemoteAddr().String()
+ username := session.User()
+ sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
+
+ s.mu.Lock()
+ s.sessions[sessionKey] = session
+ s.mu.Unlock()
+
+ return sessionKey
+}
+
+func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
+ s.mu.Lock()
+ delete(s.sessions, sessionKey)
+ delete(s.sessionJWTUsers, sessionKey)
+
+ // Cancel all port forwarding connections for this session
+ var connectionsToCancel []ConnectionKey
+ for key := range s.sessionCancels {
+ if strings.HasPrefix(string(key), string(sessionKey)+"-") {
+ connectionsToCancel = append(connectionsToCancel, key)
+ }
+ }
+
+ for _, key := range connectionsToCancel {
+ if cancelFunc, exists := s.sessionCancels[key]; exists {
+ log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
+ cancelFunc()
+ delete(s.sessionCancels, key)
+ }
+ }
+
+ if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
+ if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
+ delete(s.sshConnections, sshConn)
+ }
+ }
+
+ s.mu.Unlock()
+}
+
+func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {
+ logger.Warnf("user privilege check failed: %v", err)
+
+ errorMsg := s.buildUserLookupErrorMessage(err)
+
+ if _, writeErr := fmt.Fprint(session, errorMsg); writeErr != nil {
+ logger.Debugf(errWriteSession, writeErr)
+ }
+ if exitErr := session.Exit(1); exitErr != nil {
+ logSessionExitError(logger, exitErr)
+ }
+}
+
+// buildUserLookupErrorMessage creates appropriate user-facing error messages based on error type
+func (s *Server) buildUserLookupErrorMessage(err error) string {
+ var privilegedErr *PrivilegedUserError
+
+ switch {
+ case errors.As(err, &privilegedErr):
+ if privilegedErr.Username == "root" {
+ return "root login is disabled on this SSH server\n"
+ }
+ return "privileged user access is disabled on this SSH server\n"
+
+ case errors.Is(err, ErrPrivilegeRequired):
+ return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n"
+
+ case errors.Is(err, ErrPrivilegedUserSwitch):
+ return "Cannot switch to privileged user - current user lacks required privileges\n"
+
+ default:
+ return "User authentication failed\n"
+ }
+}
diff --git a/client/ssh/server/session_handlers_js.go b/client/ssh/server/session_handlers_js.go
new file mode 100644
index 000000000..c35e4da0b
--- /dev/null
+++ b/client/ssh/server/session_handlers_js.go
@@ -0,0 +1,22 @@
+//go:build js
+
+package server
+
+import (
+ "fmt"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+)
+
+// handlePty is not supported on JS/WASM
+func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
+ errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
+ if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
+ logger.Debugf(errWriteSession, err)
+ }
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ return false
+}
diff --git a/client/ssh/server/sftp.go b/client/ssh/server/sftp.go
new file mode 100644
index 000000000..c2b9f552b
--- /dev/null
+++ b/client/ssh/server/sftp.go
@@ -0,0 +1,81 @@
+package server
+
+import (
+ "fmt"
+ "io"
+
+ "github.com/gliderlabs/ssh"
+ "github.com/pkg/sftp"
+ log "github.com/sirupsen/logrus"
+)
+
+// SetAllowSFTP enables or disables SFTP support
+func (s *Server) SetAllowSFTP(allow bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.allowSFTP = allow
+}
+
+// sftpSubsystemHandler handles SFTP subsystem requests
+func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
+ s.mu.RLock()
+ allowSFTP := s.allowSFTP
+ s.mu.RUnlock()
+
+ if !allowSFTP {
+ log.Debugf("SFTP subsystem request denied: SFTP disabled")
+ if err := sess.Exit(1); err != nil {
+ log.Debugf("SFTP session exit failed: %v", err)
+ }
+ return
+ }
+
+ result := s.CheckPrivileges(PrivilegeCheckRequest{
+ RequestedUsername: sess.User(),
+ FeatureSupportsUserSwitch: true,
+ FeatureName: FeatureSFTP,
+ })
+
+ if !result.Allowed {
+ log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error)
+ if err := sess.Exit(1); err != nil {
+ log.Debugf("exit SFTP session: %v", err)
+ }
+ return
+ }
+
+ log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username)
+
+ if !result.RequiresUserSwitching {
+ if err := s.executeSftpDirect(sess); err != nil {
+ log.Errorf("SFTP direct execution: %v", err)
+ }
+ return
+ }
+
+ if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
+ log.Errorf("SFTP privilege drop execution: %v", err)
+ }
+}
+
+// executeSftpDirect executes SFTP directly without privilege dropping
+func (s *Server) executeSftpDirect(sess ssh.Session) error {
+ log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User())
+
+ sftpServer, err := sftp.NewServer(sess)
+ if err != nil {
+ return fmt.Errorf("SFTP server creation: %w", err)
+ }
+
+ defer func() {
+ if err := sftpServer.Close(); err != nil {
+ log.Debugf("failed to close sftp server: %v", err)
+ }
+ }()
+
+ if err := sftpServer.Serve(); err != nil && err != io.EOF {
+ return fmt.Errorf("serve: %w", err)
+ }
+
+ return nil
+}
diff --git a/client/ssh/server/sftp_js.go b/client/ssh/server/sftp_js.go
new file mode 100644
index 000000000..3b27aeff4
--- /dev/null
+++ b/client/ssh/server/sftp_js.go
@@ -0,0 +1,12 @@
+//go:build js
+
+package server
+
+import (
+ "os/user"
+)
+
+// parseUserCredentials is not supported on JS/WASM
+func (s *Server) parseUserCredentials(_ *user.User) (uint32, uint32, []uint32, error) {
+ return 0, 0, nil, errNotSupported
+}
diff --git a/client/ssh/server/sftp_test.go b/client/ssh/server/sftp_test.go
new file mode 100644
index 000000000..32a3643e4
--- /dev/null
+++ b/client/ssh/server/sftp_test.go
@@ -0,0 +1,228 @@
+package server
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/netip"
+ "os"
+ "os/user"
+ "testing"
+ "time"
+
+ "github.com/pkg/sftp"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ cryptossh "golang.org/x/crypto/ssh"
+
+ "github.com/netbirdio/netbird/client/ssh"
+)
+
+func TestSSHServer_SFTPSubsystem(t *testing.T) {
+ // Skip SFTP test when running as root due to protocol issues in some environments
+ if os.Geteuid() == 0 {
+ t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues")
+ }
+
+ // Get current user for SSH connection
+ currentUser, err := user.Current()
+ require.NoError(t, err, "Should be able to get current user")
+
+ // Generate host key for server
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ // Generate client key pair
+ clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ // Create server with SFTP enabled
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ server.SetAllowSFTP(true)
+ server.SetAllowRootLogin(true)
+
+ // Start server
+ serverAddr := "127.0.0.1:0"
+ started := make(chan string, 1)
+ errChan := make(chan error, 1)
+
+ go func() {
+ ln, err := net.Listen("tcp", serverAddr)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ actualAddr := ln.Addr().String()
+ if err := ln.Close(); err != nil {
+ errChan <- fmt.Errorf("close temp listener: %w", err)
+ return
+ }
+
+ addrPort, _ := netip.ParseAddrPort(actualAddr)
+ if err := server.Start(context.Background(), addrPort); err != nil {
+ errChan <- err
+ return
+ }
+ started <- actualAddr
+ }()
+
+ select {
+ case actualAddr := <-started:
+ serverAddr = actualAddr
+ case err := <-errChan:
+ t.Fatalf("Server failed to start: %v", err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Server start timeout")
+ }
+
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ // Parse client private key
+ signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
+ require.NoError(t, err)
+
+ // Parse server host key
+ hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
+ require.NoError(t, err)
+ hostPubKey := hostPrivParsed.PublicKey()
+
+ // (currentUser already obtained at function start)
+
+ // Create SSH client connection
+ clientConfig := &cryptossh.ClientConfig{
+ User: currentUser.Username,
+ Auth: []cryptossh.AuthMethod{
+ cryptossh.PublicKeys(signer),
+ },
+ HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
+ Timeout: 5 * time.Second,
+ }
+
+ conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
+ require.NoError(t, err, "SSH connection should succeed")
+ defer func() {
+ if err := conn.Close(); err != nil {
+ t.Logf("connection close error: %v", err)
+ }
+ }()
+
+ // Create SFTP client
+ sftpClient, err := sftp.NewClient(conn)
+ require.NoError(t, err, "SFTP client creation should succeed")
+ defer func() {
+ if err := sftpClient.Close(); err != nil {
+ t.Logf("SFTP client close error: %v", err)
+ }
+ }()
+
+ // Test basic SFTP operations
+ workingDir, err := sftpClient.Getwd()
+ assert.NoError(t, err, "Should be able to get working directory")
+ assert.NotEmpty(t, workingDir, "Working directory should not be empty")
+
+ // Test directory listing
+ files, err := sftpClient.ReadDir(".")
+ assert.NoError(t, err, "Should be able to list current directory")
+ assert.NotNil(t, files, "File list should not be nil")
+}
+
+func TestSSHServer_SFTPDisabled(t *testing.T) {
+ // Get current user for SSH connection
+ currentUser, err := user.Current()
+ require.NoError(t, err, "Should be able to get current user")
+
+ // Generate host key for server
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ // Generate client key pair
+ clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ // Create server with SFTP disabled
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ server.SetAllowSFTP(false)
+
+ // Start server
+ serverAddr := "127.0.0.1:0"
+ started := make(chan string, 1)
+ errChan := make(chan error, 1)
+
+ go func() {
+ ln, err := net.Listen("tcp", serverAddr)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ actualAddr := ln.Addr().String()
+ if err := ln.Close(); err != nil {
+ errChan <- fmt.Errorf("close temp listener: %w", err)
+ return
+ }
+
+ addrPort, _ := netip.ParseAddrPort(actualAddr)
+ if err := server.Start(context.Background(), addrPort); err != nil {
+ errChan <- err
+ return
+ }
+ started <- actualAddr
+ }()
+
+ select {
+ case actualAddr := <-started:
+ serverAddr = actualAddr
+ case err := <-errChan:
+ t.Fatalf("Server failed to start: %v", err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Server start timeout")
+ }
+
+ defer func() {
+ err := server.Stop()
+ require.NoError(t, err)
+ }()
+
+ // Parse client private key
+ signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
+ require.NoError(t, err)
+
+ // Parse server host key
+ hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
+ require.NoError(t, err)
+ hostPubKey := hostPrivParsed.PublicKey()
+
+ // (currentUser already obtained at function start)
+
+ // Create SSH client connection
+ clientConfig := &cryptossh.ClientConfig{
+ User: currentUser.Username,
+ Auth: []cryptossh.AuthMethod{
+ cryptossh.PublicKeys(signer),
+ },
+ HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
+ Timeout: 5 * time.Second,
+ }
+
+ conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
+ require.NoError(t, err, "SSH connection should succeed")
+ defer func() {
+ if err := conn.Close(); err != nil {
+ t.Logf("connection close error: %v", err)
+ }
+ }()
+
+ // Try to create SFTP client - should fail when SFTP is disabled
+ _, err = sftp.NewClient(conn)
+ assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled")
+}
diff --git a/client/ssh/server/sftp_unix.go b/client/ssh/server/sftp_unix.go
new file mode 100644
index 000000000..44202bead
--- /dev/null
+++ b/client/ssh/server/sftp_unix.go
@@ -0,0 +1,71 @@
+//go:build !windows
+
+package server
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "os/user"
+ "strconv"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+)
+
+// executeSftpWithPrivilegeDrop executes SFTP using Unix privilege dropping
+func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
+ uid, gid, groups, err := s.parseUserCredentials(targetUser)
+ if err != nil {
+ return fmt.Errorf("parse user credentials: %w", err)
+ }
+
+ sftpCmd, err := s.createSftpExecutorCommand(sess, uid, gid, groups, targetUser.HomeDir)
+ if err != nil {
+ return fmt.Errorf("create executor: %w", err)
+ }
+
+ sftpCmd.Stdin = sess
+ sftpCmd.Stdout = sess
+ sftpCmd.Stderr = sess.Stderr()
+
+ log.Tracef("starting SFTP with privilege dropping to user %s (UID=%d, GID=%d)", targetUser.Username, uid, gid)
+
+ if err := sftpCmd.Start(); err != nil {
+ return fmt.Errorf("starting SFTP executor: %w", err)
+ }
+
+ if err := sftpCmd.Wait(); err != nil {
+ var exitError *exec.ExitError
+ if errors.As(err, &exitError) {
+ log.Tracef("SFTP process exited with code %d", exitError.ExitCode())
+ return nil
+ }
+ return fmt.Errorf("exec: %w", err)
+ }
+
+ return nil
+}
+
+// createSftpExecutorCommand creates a command that spawns netbird ssh sftp for privilege dropping
+func (s *Server) createSftpExecutorCommand(sess ssh.Session, uid, gid uint32, groups []uint32, workingDir string) (*exec.Cmd, error) {
+ netbirdPath, err := os.Executable()
+ if err != nil {
+ return nil, err
+ }
+
+ args := []string{
+ "ssh", "sftp",
+ "--uid", strconv.FormatUint(uint64(uid), 10),
+ "--gid", strconv.FormatUint(uint64(gid), 10),
+ "--working-dir", workingDir,
+ }
+
+ for _, group := range groups {
+ args = append(args, "--groups", strconv.FormatUint(uint64(group), 10))
+ }
+
+ log.Tracef("creating SFTP executor command: %s %v", netbirdPath, args)
+ return exec.CommandContext(sess.Context(), netbirdPath, args...), nil
+}
diff --git a/client/ssh/server/sftp_windows.go b/client/ssh/server/sftp_windows.go
new file mode 100644
index 000000000..dc532b9e7
--- /dev/null
+++ b/client/ssh/server/sftp_windows.go
@@ -0,0 +1,91 @@
+//go:build windows
+
+package server
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "os/user"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+)
+
+// createSftpCommand creates a Windows SFTP command with user switching.
+// The caller must close the returned token handle after starting the process.
+func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, windows.Token, error) {
+ username, domain := s.parseUsername(targetUser.Username)
+
+ netbirdPath, err := os.Executable()
+ if err != nil {
+ return nil, 0, fmt.Errorf("get netbird executable path: %w", err)
+ }
+
+ args := []string{
+ "ssh", "sftp",
+ "--working-dir", targetUser.HomeDir,
+ "--windows-username", username,
+ "--windows-domain", domain,
+ }
+
+ pd := NewPrivilegeDropper()
+ token, err := pd.createToken(username, domain)
+ if err != nil {
+ return nil, 0, fmt.Errorf("create token: %w", err)
+ }
+
+ defer func() {
+ if err := windows.CloseHandle(token); err != nil {
+ log.Warnf("failed to close impersonation token: %v", err)
+ }
+ }()
+
+ cmd, primaryToken, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir)
+ if err != nil {
+ return nil, 0, fmt.Errorf("create SFTP command: %w", err)
+ }
+
+ log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username)
+ return cmd, primaryToken, nil
+}
+
+// executeSftpCommand executes a Windows SFTP command with proper I/O handling
+func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token windows.Token) error {
+ defer func() {
+ if err := windows.CloseHandle(windows.Handle(token)); err != nil {
+ log.Debugf("close primary token: %v", err)
+ }
+ }()
+
+ sftpCmd.Stdin = sess
+ sftpCmd.Stdout = sess
+ sftpCmd.Stderr = sess.Stderr()
+
+ if err := sftpCmd.Start(); err != nil {
+ return fmt.Errorf("starting sftp executor: %w", err)
+ }
+
+ if err := sftpCmd.Wait(); err != nil {
+ var exitError *exec.ExitError
+ if errors.As(err, &exitError) {
+ log.Tracef("sftp process exited with code %d", exitError.ExitCode())
+ return nil
+ }
+
+ return fmt.Errorf("exec sftp: %w", err)
+ }
+
+ return nil
+}
+
+// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping
+func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
+ sftpCmd, token, err := s.createSftpCommand(targetUser, sess)
+ if err != nil {
+ return fmt.Errorf("create sftp: %w", err)
+ }
+ return s.executeSftpCommand(sess, sftpCmd, token)
+}
diff --git a/client/ssh/server/shell.go b/client/ssh/server/shell.go
new file mode 100644
index 000000000..fea9d2910
--- /dev/null
+++ b/client/ssh/server/shell.go
@@ -0,0 +1,180 @@
+package server
+
+import (
+ "bufio"
+ "fmt"
+ "net"
+ "os"
+ "os/exec"
+ "os/user"
+ "runtime"
+ "strconv"
+ "strings"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ defaultUnixShell = "/bin/sh"
+
+ pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name
+ powershellExe = "powershell.exe"
+)
+
+// getUserShell returns the appropriate shell for the given user ID
+// Handles all platform-specific logic and fallbacks consistently
+func getUserShell(userID string) string {
+ switch runtime.GOOS {
+ case "windows":
+ return getWindowsUserShell()
+ default:
+ return getUnixUserShell(userID)
+ }
+}
+
+// getWindowsUserShell returns the best shell for Windows users.
+// We intentionally do not support cmd.exe or COMSPEC fallbacks to avoid command injection
+// vulnerabilities that arise from cmd.exe's complex command line parsing and special characters.
+// PowerShell provides safer argument handling and is available on all modern Windows systems.
+// Order: pwsh.exe -> powershell.exe
+func getWindowsUserShell() string {
+ if path, err := exec.LookPath(pwshExe); err == nil {
+ return path
+ }
+ if path, err := exec.LookPath(powershellExe); err == nil {
+ return path
+ }
+
+ return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`
+}
+
+// getUnixUserShell returns the shell for Unix-like systems
+func getUnixUserShell(userID string) string {
+ shell := getShellFromPasswd(userID)
+ if shell != "" {
+ return shell
+ }
+
+ if shell := os.Getenv("SHELL"); shell != "" {
+ return shell
+ }
+
+ return defaultUnixShell
+}
+
+// getShellFromPasswd reads the shell from /etc/passwd for the given user ID
+func getShellFromPasswd(userID string) string {
+ file, err := os.Open("/etc/passwd")
+ if err != nil {
+ return ""
+ }
+ defer func() {
+ if err := file.Close(); err != nil {
+ log.Warnf("close /etc/passwd file: %v", err)
+ }
+ }()
+
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ line := scanner.Text()
+ fields := strings.Split(line, ":")
+ if len(fields) < 7 {
+ continue
+ }
+
+ // field 2 is UID
+ if fields[2] == userID {
+ shell := strings.TrimSpace(fields[6])
+ return shell
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ log.Warnf("error reading /etc/passwd: %v", err)
+ }
+
+ return ""
+}
+
+// prepareUserEnv prepares environment variables for user execution
+func prepareUserEnv(user *user.User, shell string) []string {
+ pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
+ if runtime.GOOS == "windows" {
+ pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
+ }
+
+ return []string{
+ fmt.Sprint("SHELL=" + shell),
+ fmt.Sprint("USER=" + user.Username),
+ fmt.Sprint("LOGNAME=" + user.Username),
+ fmt.Sprint("HOME=" + user.HomeDir),
+ "PATH=" + pathValue,
+ }
+}
+
+// acceptEnv checks if environment variable from SSH client should be accepted
+// This is a whitelist of variables that SSH clients can send to the server
+func acceptEnv(envVar string) bool {
+ varName := envVar
+ if idx := strings.Index(envVar, "="); idx != -1 {
+ varName = envVar[:idx]
+ }
+
+ exactMatches := []string{
+ "LANG",
+ "LANGUAGE",
+ "TERM",
+ "COLORTERM",
+ "EDITOR",
+ "VISUAL",
+ "PAGER",
+ "LESS",
+ "LESSCHARSET",
+ "TZ",
+ }
+
+ prefixMatches := []string{
+ "LC_",
+ }
+
+ for _, exact := range exactMatches {
+ if varName == exact {
+ return true
+ }
+ }
+
+ for _, prefix := range prefixMatches {
+ if strings.HasPrefix(varName, prefix) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// prepareSSHEnv prepares SSH protocol-specific environment variables
+// These variables provide information about the SSH connection itself
+func prepareSSHEnv(session ssh.Session) []string {
+ remoteAddr := session.RemoteAddr()
+ localAddr := session.LocalAddr()
+
+ remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
+ if err != nil {
+ remoteHost = remoteAddr.String()
+ remotePort = "0"
+ }
+
+ localHost, localPort, err := net.SplitHostPort(localAddr.String())
+ if err != nil {
+ localHost = localAddr.String()
+ localPort = strconv.Itoa(InternalSSHPort)
+ }
+
+ return []string{
+ // SSH_CLIENT format: "client_ip client_port server_port"
+ fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
+ // SSH_CONNECTION format: "client_ip client_port server_ip server_port"
+ fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
+ }
+}
diff --git a/client/ssh/server/test.go b/client/ssh/server/test.go
new file mode 100644
index 000000000..20930c721
--- /dev/null
+++ b/client/ssh/server/test.go
@@ -0,0 +1,45 @@
+package server
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/netip"
+ "testing"
+ "time"
+)
+
+func StartTestServer(t *testing.T, server *Server) string {
+ started := make(chan string, 1)
+ errChan := make(chan error, 1)
+
+ go func() {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ errChan <- err
+ return
+ }
+ actualAddr := ln.Addr().String()
+ if err := ln.Close(); err != nil {
+ errChan <- fmt.Errorf("close temp listener: %w", err)
+ return
+ }
+
+ addrPort := netip.MustParseAddrPort(actualAddr)
+ if err := server.Start(context.Background(), addrPort); err != nil {
+ errChan <- err
+ return
+ }
+ started <- actualAddr
+ }()
+
+ select {
+ case actualAddr := <-started:
+ return actualAddr
+ case err := <-errChan:
+ t.Fatalf("Server failed to start: %v", err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Server start timeout")
+ }
+ return ""
+}
diff --git a/client/ssh/server/user_utils.go b/client/ssh/server/user_utils.go
new file mode 100644
index 000000000..799882cbb
--- /dev/null
+++ b/client/ssh/server/user_utils.go
@@ -0,0 +1,411 @@
+package server
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "os/user"
+ "runtime"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+)
+
+var (
+ ErrPrivilegeRequired = errors.New("SeAssignPrimaryTokenPrivilege required for user switching - NetBird must run with elevated privileges")
+ ErrPrivilegedUserSwitch = errors.New("cannot switch to privileged user - current user lacks required privileges")
+)
+
+// isPlatformUnix returns true for Unix-like platforms (Linux, macOS, etc.)
+func isPlatformUnix() bool {
+ return getCurrentOS() != "windows"
+}
+
+// Dependency injection variables for testing - allows mocking dynamic runtime checks
+var (
+ getCurrentUser = user.Current
+ lookupUser = user.Lookup
+ getCurrentOS = func() string { return runtime.GOOS }
+ getIsProcessPrivileged = isCurrentProcessPrivileged
+
+ getEuid = os.Geteuid
+)
+
+const (
+ // FeatureSSHLogin represents SSH login operations for privilege checking
+ FeatureSSHLogin = "SSH login"
+ // FeatureSFTP represents SFTP operations for privilege checking
+ FeatureSFTP = "SFTP"
+)
+
+// PrivilegeCheckRequest represents a privilege check request
+type PrivilegeCheckRequest struct {
+ // Username being requested (empty = current user)
+ RequestedUsername string
+ FeatureSupportsUserSwitch bool // Does this feature/operation support user switching?
+ FeatureName string
+}
+
+// PrivilegeCheckResult represents the result of a privilege check
+type PrivilegeCheckResult struct {
+ // Allowed indicates whether the privilege check passed
+ Allowed bool
+ // User is the effective user to use for the operation (nil if not allowed)
+ User *user.User
+ // Error contains the reason for denial (nil if allowed)
+ Error error
+ // UsedFallback indicates we fell back to current user instead of requested user.
+ // This happens on Unix when running as an unprivileged user (e.g., in containers)
+ // where there's no point in user switching since we lack privileges anyway.
+ // When true, all privilege checks have already been performed and no additional
+ // privilege dropping or root checks are needed - the current user is the target.
+ UsedFallback bool
+ // RequiresUserSwitching indicates whether user switching will actually occur
+ // (false for fallback cases where no actual switching happens)
+ RequiresUserSwitching bool
+}
+
+// CheckPrivileges performs comprehensive privilege checking for all SSH features.
+// This is the single source of truth for privilege decisions across the SSH server.
+func (s *Server) CheckPrivileges(req PrivilegeCheckRequest) PrivilegeCheckResult {
+ context, err := s.buildPrivilegeCheckContext(req.FeatureName)
+ if err != nil {
+ return PrivilegeCheckResult{Allowed: false, Error: err}
+ }
+
+ // Handle empty username case - but still check root access controls
+ if req.RequestedUsername == "" {
+ if isPrivilegedUsername(context.currentUser.Username) && !context.allowRoot {
+ return PrivilegeCheckResult{
+ Allowed: false,
+ Error: &PrivilegedUserError{Username: context.currentUser.Username},
+ }
+ }
+ return PrivilegeCheckResult{
+ Allowed: true,
+ User: context.currentUser,
+ RequiresUserSwitching: false,
+ }
+ }
+
+ return s.checkUserRequest(context, req)
+}
+
+// buildPrivilegeCheckContext gathers all the context needed for privilege checking
+func (s *Server) buildPrivilegeCheckContext(featureName string) (*privilegeCheckContext, error) {
+ currentUser, err := getCurrentUser()
+ if err != nil {
+ return nil, fmt.Errorf("get current user for %s: %w", featureName, err)
+ }
+
+ s.mu.RLock()
+ allowRoot := s.allowRootLogin
+ s.mu.RUnlock()
+
+ return &privilegeCheckContext{
+ currentUser: currentUser,
+ currentUserPrivileged: getIsProcessPrivileged(),
+ allowRoot: allowRoot,
+ }, nil
+}
+
+// checkUserRequest handles normal privilege checking flow for specific usernames
+func (s *Server) checkUserRequest(ctx *privilegeCheckContext, req PrivilegeCheckRequest) PrivilegeCheckResult {
+ if !ctx.currentUserPrivileged && isPlatformUnix() {
+ log.Debugf("Unix non-privileged shortcut: falling back to current user %s for %s (requested: %s)",
+ ctx.currentUser.Username, req.FeatureName, req.RequestedUsername)
+ return PrivilegeCheckResult{
+ Allowed: true,
+ User: ctx.currentUser,
+ UsedFallback: true,
+ RequiresUserSwitching: false,
+ }
+ }
+
+ resolvedUser, err := s.resolveRequestedUser(req.RequestedUsername)
+ if err != nil {
+ // Calculate if user switching would be required even if lookup failed
+ needsUserSwitching := !isSameUser(req.RequestedUsername, ctx.currentUser.Username)
+ return PrivilegeCheckResult{
+ Allowed: false,
+ Error: err,
+ RequiresUserSwitching: needsUserSwitching,
+ }
+ }
+
+ needsUserSwitching := !isSameResolvedUser(resolvedUser, ctx.currentUser)
+
+ if isPrivilegedUsername(resolvedUser.Username) && !ctx.allowRoot {
+ return PrivilegeCheckResult{
+ Allowed: false,
+ Error: &PrivilegedUserError{Username: resolvedUser.Username},
+ RequiresUserSwitching: needsUserSwitching,
+ }
+ }
+
+ if needsUserSwitching && !req.FeatureSupportsUserSwitch {
+ return PrivilegeCheckResult{
+ Allowed: false,
+ Error: fmt.Errorf("%s: user switching not supported by this feature", req.FeatureName),
+ RequiresUserSwitching: needsUserSwitching,
+ }
+ }
+
+ return PrivilegeCheckResult{
+ Allowed: true,
+ User: resolvedUser,
+ RequiresUserSwitching: needsUserSwitching,
+ }
+}
+
+// resolveRequestedUser resolves a username to its canonical user identity
+func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, error) {
+ if requestedUsername == "" {
+ return getCurrentUser()
+ }
+
+ if err := validateUsername(requestedUsername); err != nil {
+ return nil, fmt.Errorf("invalid username %q: %w", requestedUsername, err)
+ }
+
+ u, err := lookupUser(requestedUsername)
+ if err != nil {
+ return nil, &UserNotFoundError{Username: requestedUsername, Cause: err}
+ }
+ return u, nil
+}
+
+// isSameResolvedUser compares two resolved user identities
+func isSameResolvedUser(user1, user2 *user.User) bool {
+ if user1 == nil || user2 == nil {
+ return user1 == user2
+ }
+ return user1.Uid == user2.Uid
+}
+
+// privilegeCheckContext holds all context needed for privilege checking
+type privilegeCheckContext struct {
+ currentUser *user.User
+ currentUserPrivileged bool
+ allowRoot bool
+}
+
+// isSameUser checks if two usernames refer to the same user
+// SECURITY: This function must be conservative - it should only return true
+// when we're certain both usernames refer to the exact same user identity
+func isSameUser(requestedUsername, currentUsername string) bool {
+ // Empty requested username means current user
+ if requestedUsername == "" {
+ return true
+ }
+
+ // Exact match (most common case)
+ if getCurrentOS() == "windows" {
+ if strings.EqualFold(requestedUsername, currentUsername) {
+ return true
+ }
+ } else {
+ if requestedUsername == currentUsername {
+ return true
+ }
+ }
+
+ // Windows domain resolution: only allow domain stripping when comparing
+ // a bare username against the current user's domain-qualified name
+ if getCurrentOS() == "windows" {
+ return isWindowsSameUser(requestedUsername, currentUsername)
+ }
+
+ return false
+}
+
+// isWindowsSameUser handles Windows-specific user comparison with domain logic
+func isWindowsSameUser(requestedUsername, currentUsername string) bool {
+ // Extract domain and username parts
+ extractParts := func(name string) (domain, user string) {
+ // Handle DOMAIN\username format
+ if idx := strings.LastIndex(name, `\`); idx != -1 {
+ return name[:idx], name[idx+1:]
+ }
+ // Handle user@domain.com format
+ if idx := strings.Index(name, "@"); idx != -1 {
+ return name[idx+1:], name[:idx]
+ }
+ // No domain specified - local machine
+ return "", name
+ }
+
+ reqDomain, reqUser := extractParts(requestedUsername)
+ curDomain, curUser := extractParts(currentUsername)
+
+ // Case-insensitive username comparison
+ if !strings.EqualFold(reqUser, curUser) {
+ return false
+ }
+
+ // If requested username has no domain, it refers to local machine user
+ // Allow this to match the current user regardless of current user's domain
+ if reqDomain == "" {
+ return true
+ }
+
+ // If both have domains, they must match exactly (case-insensitive)
+ return strings.EqualFold(reqDomain, curDomain)
+}
+
+// SetAllowRootLogin configures root login access
+func (s *Server) SetAllowRootLogin(allow bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.allowRootLogin = allow
+}
+
+// userNameLookup performs user lookup with root login permission check
+func (s *Server) userNameLookup(username string) (*user.User, error) {
+ result := s.CheckPrivileges(PrivilegeCheckRequest{
+ RequestedUsername: username,
+ FeatureSupportsUserSwitch: true,
+ FeatureName: FeatureSSHLogin,
+ })
+
+ if !result.Allowed {
+ return nil, result.Error
+ }
+
+ return result.User, nil
+}
+
+// userPrivilegeCheck performs user lookup with full privilege check result
+func (s *Server) userPrivilegeCheck(username string) (PrivilegeCheckResult, error) {
+ result := s.CheckPrivileges(PrivilegeCheckRequest{
+ RequestedUsername: username,
+ FeatureSupportsUserSwitch: true,
+ FeatureName: FeatureSSHLogin,
+ })
+
+ if !result.Allowed {
+ return result, result.Error
+ }
+
+ return result, nil
+}
+
+// isPrivilegedUsername checks if the given username represents a privileged user across platforms.
+// On Unix: root
+// On Windows: Administrator, SYSTEM (case-insensitive)
+// Handles domain-qualified usernames like "DOMAIN\Administrator" or "user@domain.com"
+func isPrivilegedUsername(username string) bool {
+ if getCurrentOS() != "windows" {
+ return username == "root"
+ }
+
+ bareUsername := username
+ // Handle Windows domain format: DOMAIN\username
+ if idx := strings.LastIndex(username, `\`); idx != -1 {
+ bareUsername = username[idx+1:]
+ }
+ // Handle email-style format: username@domain.com
+ if idx := strings.Index(bareUsername, "@"); idx != -1 {
+ bareUsername = bareUsername[:idx]
+ }
+
+ return isWindowsPrivilegedUser(bareUsername)
+}
+
+// isWindowsPrivilegedUser checks if a bare username (domain already stripped) represents a Windows privileged account
+func isWindowsPrivilegedUser(bareUsername string) bool {
+ // common privileged usernames (case insensitive)
+ privilegedNames := []string{
+ "administrator",
+ "admin",
+ "root",
+ "system",
+ "localsystem",
+ "networkservice",
+ "localservice",
+ }
+
+ usernameLower := strings.ToLower(bareUsername)
+ for _, privilegedName := range privilegedNames {
+ if usernameLower == privilegedName {
+ return true
+ }
+ }
+
+ // computer accounts (ending with $) are not privileged by themselves
+ // They only gain privileges through group membership or specific SIDs
+
+ if targetUser, err := lookupUser(bareUsername); err == nil {
+ return isWindowsPrivilegedSID(targetUser.Uid)
+ }
+
+ return false
+}
+
+// isWindowsPrivilegedSID checks if a Windows SID represents a privileged account
+func isWindowsPrivilegedSID(sid string) bool {
+ privilegedSIDs := []string{
+ "S-1-5-18", // Local System (SYSTEM)
+ "S-1-5-19", // Local Service (NT AUTHORITY\LOCAL SERVICE)
+ "S-1-5-20", // Network Service (NT AUTHORITY\NETWORK SERVICE)
+ "S-1-5-32-544", // Administrators group (BUILTIN\Administrators)
+ "S-1-5-500", // Built-in Administrator account (local machine RID 500)
+ }
+
+ for _, privilegedSID := range privilegedSIDs {
+ if sid == privilegedSID {
+ return true
+ }
+ }
+
+ // Check for domain administrator accounts (RID 500 in any domain)
+ // Format: S-1-5-21-domain-domain-domain-500
+ // This is reliable as RID 500 is reserved for the domain Administrator account
+ if strings.HasPrefix(sid, "S-1-5-21-") && strings.HasSuffix(sid, "-500") {
+ return true
+ }
+
+ // Check for other well-known privileged RIDs in domain contexts
+ // RID 512 = Domain Admins group, RID 516 = Domain Controllers group
+ if strings.HasPrefix(sid, "S-1-5-21-") {
+ if strings.HasSuffix(sid, "-512") || // Domain Admins group
+ strings.HasSuffix(sid, "-516") || // Domain Controllers group
+ strings.HasSuffix(sid, "-519") { // Enterprise Admins group
+ return true
+ }
+ }
+
+ return false
+}
+
+// isCurrentProcessPrivileged checks if the current process is running with elevated privileges.
+// On Unix systems, this means running as root (UID 0).
+// On Windows, this means running as Administrator or SYSTEM.
+func isCurrentProcessPrivileged() bool {
+ if getCurrentOS() == "windows" {
+ return isWindowsElevated()
+ }
+ return getEuid() == 0
+}
+
+// isWindowsElevated checks if the current process is running with elevated privileges on Windows
+func isWindowsElevated() bool {
+ currentUser, err := getCurrentUser()
+ if err != nil {
+ log.Errorf("failed to get current user for privilege check, assuming non-privileged: %v", err)
+ return false
+ }
+
+ if isWindowsPrivilegedSID(currentUser.Uid) {
+ log.Debugf("Windows user switching supported: running as privileged SID %s", currentUser.Uid)
+ return true
+ }
+
+ if isPrivilegedUsername(currentUser.Username) {
+ log.Debugf("Windows user switching supported: running as privileged username %s", currentUser.Username)
+ return true
+ }
+
+ log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
+ return false
+}
diff --git a/client/ssh/server/user_utils_js.go b/client/ssh/server/user_utils_js.go
new file mode 100644
index 000000000..163b24c6c
--- /dev/null
+++ b/client/ssh/server/user_utils_js.go
@@ -0,0 +1,8 @@
+//go:build js
+
+package server
+
+// validateUsername is not supported on JS/WASM
+func validateUsername(_ string) error {
+ return errNotSupported
+}
diff --git a/client/ssh/server/user_utils_test.go b/client/ssh/server/user_utils_test.go
new file mode 100644
index 000000000..637dc10d0
--- /dev/null
+++ b/client/ssh/server/user_utils_test.go
@@ -0,0 +1,908 @@
+package server
+
+import (
+ "errors"
+ "os/user"
+ "runtime"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// Test helper functions
+func createTestUser(username, uid, gid, homeDir string) *user.User {
+ return &user.User{
+ Uid: uid,
+ Gid: gid,
+ Username: username,
+ Name: username,
+ HomeDir: homeDir,
+ }
+}
+
+// Test dependency injection setup - injects platform dependencies to test real logic
+func setupTestDependencies(currentUser *user.User, currentUserErr error, os string, euid int, lookupUsers map[string]*user.User, lookupErrors map[string]error) func() {
+ // Store originals
+ originalGetCurrentUser := getCurrentUser
+ originalLookupUser := lookupUser
+ originalGetCurrentOS := getCurrentOS
+ originalGetEuid := getEuid
+
+ // Reset caches to ensure clean test state
+
+ // Set test values - inject platform dependencies
+ getCurrentUser = func() (*user.User, error) {
+ return currentUser, currentUserErr
+ }
+
+ lookupUser = func(username string) (*user.User, error) {
+ if err, exists := lookupErrors[username]; exists {
+ return nil, err
+ }
+ if userObj, exists := lookupUsers[username]; exists {
+ return userObj, nil
+ }
+ return nil, errors.New("user: unknown user " + username)
+ }
+
+ getCurrentOS = func() string {
+ return os
+ }
+
+ getEuid = func() int {
+ return euid
+ }
+
+ // Mock privilege detection based on the test user
+ getIsProcessPrivileged = func() bool {
+ if currentUser == nil {
+ return false
+ }
+ // Check both username and SID for Windows systems
+ if os == "windows" && isWindowsPrivilegedSID(currentUser.Uid) {
+ return true
+ }
+ return isPrivilegedUsername(currentUser.Username)
+ }
+
+ // Return cleanup function
+ return func() {
+ getCurrentUser = originalGetCurrentUser
+ lookupUser = originalLookupUser
+ getCurrentOS = originalGetCurrentOS
+ getEuid = originalGetEuid
+
+ getIsProcessPrivileged = isCurrentProcessPrivileged
+
+ // Reset caches after test
+ }
+}
+
+func TestCheckPrivileges_ComprehensiveMatrix(t *testing.T) {
+ tests := []struct {
+ name string
+ os string
+ euid int
+ currentUser *user.User
+ requestedUsername string
+ featureSupportsUserSwitch bool
+ allowRoot bool
+ lookupUsers map[string]*user.User
+ expectedAllowed bool
+ expectedRequiresSwitch bool
+ }{
+ {
+ name: "linux_root_can_switch_to_alice",
+ os: "linux",
+ euid: 0, // Root process
+ currentUser: createTestUser("root", "0", "0", "/root"),
+ requestedUsername: "alice",
+ featureSupportsUserSwitch: true,
+ allowRoot: true,
+ lookupUsers: map[string]*user.User{
+ "alice": createTestUser("alice", "1000", "1000", "/home/alice"),
+ },
+ expectedAllowed: true,
+ expectedRequiresSwitch: true,
+ },
+ {
+ name: "linux_non_root_fallback_to_current_user",
+ os: "linux",
+ euid: 1000, // Non-root process
+ currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
+ requestedUsername: "bob",
+ featureSupportsUserSwitch: true,
+ allowRoot: true,
+ expectedAllowed: true, // Should fallback to current user (alice)
+ expectedRequiresSwitch: false, // Fallback means no actual switching
+ },
+ {
+ name: "windows_admin_can_switch_to_alice",
+ os: "windows",
+ euid: 1000, // Irrelevant on Windows
+ currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
+ requestedUsername: "alice",
+ featureSupportsUserSwitch: true,
+ allowRoot: true,
+ lookupUsers: map[string]*user.User{
+ "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
+ },
+ expectedAllowed: true,
+ expectedRequiresSwitch: true,
+ },
+ {
+ name: "windows_non_admin_no_fallback_hard_failure",
+ os: "windows",
+ euid: 1000, // Irrelevant on Windows
+ currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
+ requestedUsername: "bob",
+ featureSupportsUserSwitch: true,
+ allowRoot: true,
+ lookupUsers: map[string]*user.User{
+ "bob": createTestUser("bob", "S-1-5-21-123456789-123456789-123456789-1002", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\bob"),
+ },
+ expectedAllowed: true, // Let OS decide - deferred security check
+ expectedRequiresSwitch: true, // Different user was requested
+ },
+ // Comprehensive test matrix: non-root linux with different allowRoot settings
+ {
+ name: "linux_non_root_request_root_allowRoot_false",
+ os: "linux",
+ euid: 1000,
+ currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
+ requestedUsername: "root",
+ featureSupportsUserSwitch: true,
+ allowRoot: false,
+ expectedAllowed: true, // Fallback allows access regardless of root setting
+ expectedRequiresSwitch: false, // Fallback case, no switching
+ },
+ {
+ name: "linux_non_root_request_root_allowRoot_true",
+ os: "linux",
+ euid: 1000,
+ currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
+ requestedUsername: "root",
+ featureSupportsUserSwitch: true,
+ allowRoot: true,
+ expectedAllowed: true, // Should fallback to alice (non-privileged process)
+ expectedRequiresSwitch: false, // Fallback means no actual switching
+ },
+ // Windows admin test matrix
+ {
+ name: "windows_admin_request_root_allowRoot_false",
+ os: "windows",
+ euid: 1000,
+ currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
+ requestedUsername: "root",
+ featureSupportsUserSwitch: true,
+ allowRoot: false,
+ expectedAllowed: false, // Root not allowed
+ expectedRequiresSwitch: true,
+ },
+ {
+ name: "windows_admin_request_root_allowRoot_true",
+ os: "windows",
+ euid: 1000,
+ currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
+ requestedUsername: "root",
+ featureSupportsUserSwitch: true,
+ allowRoot: true,
+ lookupUsers: map[string]*user.User{
+ "root": createTestUser("root", "0", "0", "/root"),
+ },
+ expectedAllowed: true, // Windows user switching should work like Unix
+ expectedRequiresSwitch: true,
+ },
+ // Windows non-admin test matrix
+ {
+ name: "windows_non_admin_request_root_allowRoot_false",
+ os: "windows",
+ euid: 1000,
+ currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
+ requestedUsername: "root",
+ featureSupportsUserSwitch: true,
+ allowRoot: false,
+ expectedAllowed: false, // Root not allowed (allowRoot=false takes precedence)
+ expectedRequiresSwitch: true,
+ },
+ {
+ name: "windows_system_account_allowRoot_false",
+ os: "windows",
+ euid: 1000,
+ currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
+ requestedUsername: "root",
+ featureSupportsUserSwitch: true,
+ allowRoot: false,
+ expectedAllowed: false, // Root not allowed
+ expectedRequiresSwitch: true,
+ },
+ {
+ name: "windows_system_account_allowRoot_true",
+ os: "windows",
+ euid: 1000,
+ currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
+ requestedUsername: "root",
+ featureSupportsUserSwitch: true,
+ allowRoot: true,
+ lookupUsers: map[string]*user.User{
+ "root": createTestUser("root", "0", "0", "/root"),
+ },
+ expectedAllowed: true, // SYSTEM can switch to root
+ expectedRequiresSwitch: true,
+ },
+ {
+ name: "windows_non_admin_request_root_allowRoot_true",
+ os: "windows",
+ euid: 1000,
+ currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
+ requestedUsername: "root",
+ featureSupportsUserSwitch: true,
+ allowRoot: true,
+ lookupUsers: map[string]*user.User{
+ "root": createTestUser("root", "0", "0", "/root"),
+ },
+ expectedAllowed: true, // Let OS decide - deferred security check
+ expectedRequiresSwitch: true,
+ },
+
+ // Feature doesn't support user switching scenarios
+ {
+ name: "linux_root_feature_no_user_switching_same_user",
+ os: "linux",
+ euid: 0,
+ currentUser: createTestUser("root", "0", "0", "/root"),
+ requestedUsername: "root", // Same user
+ featureSupportsUserSwitch: false,
+ allowRoot: true,
+ lookupUsers: map[string]*user.User{
+ "root": createTestUser("root", "0", "0", "/root"),
+ },
+ expectedAllowed: true, // Same user should work regardless of feature support
+ expectedRequiresSwitch: false,
+ },
+ {
+ name: "linux_root_feature_no_user_switching_different_user",
+ os: "linux",
+ euid: 0,
+ currentUser: createTestUser("root", "0", "0", "/root"),
+ requestedUsername: "alice",
+ featureSupportsUserSwitch: false, // Feature doesn't support switching
+ allowRoot: true,
+ lookupUsers: map[string]*user.User{
+ "alice": createTestUser("alice", "1000", "1000", "/home/alice"),
+ },
+ expectedAllowed: false, // Should deny because feature doesn't support switching
+ expectedRequiresSwitch: true,
+ },
+
+ // Empty username (current user) scenarios
+ {
+ name: "linux_non_root_current_user_empty_username",
+ os: "linux",
+ euid: 1000,
+ currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
+ requestedUsername: "", // Empty = current user
+ featureSupportsUserSwitch: true,
+ allowRoot: false,
+ expectedAllowed: true, // Current user should always work
+ expectedRequiresSwitch: false,
+ },
+ {
+ name: "linux_root_current_user_empty_username_root_not_allowed",
+ os: "linux",
+ euid: 0,
+ currentUser: createTestUser("root", "0", "0", "/root"),
+ requestedUsername: "", // Empty = current user (root)
+ featureSupportsUserSwitch: true,
+ allowRoot: false, // Root not allowed
+ expectedAllowed: false, // Should deny root even when it's current user
+ expectedRequiresSwitch: false,
+ },
+
+ // User not found scenarios
+ {
+ name: "linux_root_user_not_found",
+ os: "linux",
+ euid: 0,
+ currentUser: createTestUser("root", "0", "0", "/root"),
+ requestedUsername: "nonexistent",
+ featureSupportsUserSwitch: true,
+ allowRoot: true,
+ lookupUsers: map[string]*user.User{}, // No users defined = user not found
+ expectedAllowed: false, // Should fail due to user not found
+ expectedRequiresSwitch: true,
+ },
+
+ // Windows feature doesn't support user switching
+ {
+ name: "windows_admin_feature_no_user_switching_different_user",
+ os: "windows",
+ euid: 1000,
+ currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
+ requestedUsername: "alice",
+ featureSupportsUserSwitch: false, // Feature doesn't support switching
+ allowRoot: true,
+ lookupUsers: map[string]*user.User{
+ "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
+ },
+ expectedAllowed: false, // Should deny because feature doesn't support switching
+ expectedRequiresSwitch: true,
+ },
+
+ // Windows regular user scenarios (non-admin)
+ {
+ name: "windows_regular_user_same_user",
+ os: "windows",
+ euid: 1000,
+ currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
+ requestedUsername: "alice", // Same user
+ featureSupportsUserSwitch: true,
+ allowRoot: false,
+ lookupUsers: map[string]*user.User{
+ "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
+ },
+ expectedAllowed: true, // Regular user accessing themselves should work
+ expectedRequiresSwitch: false, // No switching for same user
+ },
+ {
+ name: "windows_regular_user_empty_username",
+ os: "windows",
+ euid: 1000,
+ currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
+ requestedUsername: "", // Empty = current user
+ featureSupportsUserSwitch: true,
+ allowRoot: false,
+ expectedAllowed: true, // Current user should always work
+ expectedRequiresSwitch: false, // No switching for current user
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Inject platform dependencies to test real logic
+ cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, tt.lookupUsers, nil)
+ defer cleanup()
+
+ server := &Server{allowRootLogin: tt.allowRoot}
+
+ result := server.CheckPrivileges(PrivilegeCheckRequest{
+ RequestedUsername: tt.requestedUsername,
+ FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
+ FeatureName: "SSH login",
+ })
+
+ assert.Equal(t, tt.expectedAllowed, result.Allowed)
+ assert.Equal(t, tt.expectedRequiresSwitch, result.RequiresUserSwitching)
+ })
+ }
+}
+
+func TestUsedFallback_MeansNoPrivilegeDropping(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("Fallback mechanism is Unix-specific")
+ }
+
+ // Create test scenario where fallback should occur
+ server := &Server{allowRootLogin: true}
+
+ // Mock dependencies to simulate non-privileged user
+ originalGetCurrentUser := getCurrentUser
+ originalGetIsProcessPrivileged := getIsProcessPrivileged
+
+ defer func() {
+ getCurrentUser = originalGetCurrentUser
+ getIsProcessPrivileged = originalGetIsProcessPrivileged
+
+ }()
+
+ // Set up mocks for fallback scenario
+ getCurrentUser = func() (*user.User, error) {
+ return createTestUser("netbird", "1000", "1000", "/var/lib/netbird"), nil
+ }
+ getIsProcessPrivileged = func() bool { return false } // Non-privileged
+
+ // Request different user - should fallback
+ result := server.CheckPrivileges(PrivilegeCheckRequest{
+ RequestedUsername: "alice",
+ FeatureSupportsUserSwitch: true,
+ FeatureName: "SSH login",
+ })
+
+ // Verify fallback occurred
+ assert.True(t, result.Allowed, "Should allow with fallback")
+ assert.True(t, result.UsedFallback, "Should indicate fallback was used")
+ assert.Equal(t, "netbird", result.User.Username, "Should return current user")
+ assert.False(t, result.RequiresUserSwitching, "Should not require switching when fallback is used")
+
+ // Key assertion: When UsedFallback is true, no privilege dropping should be needed
+ // because all privilege checks have already been performed and we're using current user
+ t.Logf("UsedFallback=true means: current user (%s) is the target, no privilege dropping needed",
+ result.User.Username)
+}
+
+func TestPrivilegedUsernameDetection(t *testing.T) {
+ tests := []struct {
+ name string
+ username string
+ platform string
+ privileged bool
+ }{
+ // Unix/Linux tests
+ {"unix_root", "root", "linux", true},
+ {"unix_regular_user", "alice", "linux", false},
+ {"unix_root_capital", "Root", "linux", false}, // Case-sensitive
+
+ // Windows tests
+ {"windows_administrator", "Administrator", "windows", true},
+ {"windows_system", "SYSTEM", "windows", true},
+ {"windows_admin", "admin", "windows", true},
+ {"windows_admin_lowercase", "administrator", "windows", true}, // Case-insensitive
+ {"windows_domain_admin", "DOMAIN\\Administrator", "windows", true},
+ {"windows_email_admin", "admin@domain.com", "windows", true},
+ {"windows_regular_user", "alice", "windows", false},
+ {"windows_domain_user", "DOMAIN\\alice", "windows", false},
+ {"windows_localsystem", "localsystem", "windows", true},
+ {"windows_networkservice", "networkservice", "windows", true},
+ {"windows_localservice", "localservice", "windows", true},
+
+ // Computer accounts (these depend on current user context in real implementation)
+ {"windows_computer_account", "WIN2K19-C2$", "windows", false}, // Computer account by itself not privileged
+ {"windows_domain_computer", "DOMAIN\\COMPUTER$", "windows", false}, // Domain computer account
+
+ // Cross-platform
+ {"root_on_windows", "root", "windows", true}, // Root should be privileged everywhere
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Mock the platform for this test
+ cleanup := setupTestDependencies(nil, nil, tt.platform, 1000, nil, nil)
+ defer cleanup()
+
+ result := isPrivilegedUsername(tt.username)
+ assert.Equal(t, tt.privileged, result)
+ })
+ }
+}
+
+func TestWindowsPrivilegedSIDDetection(t *testing.T) {
+ tests := []struct {
+ name string
+ sid string
+ privileged bool
+ description string
+ }{
+ // Well-known system accounts
+ {"system_account", "S-1-5-18", true, "Local System (SYSTEM)"},
+ {"local_service", "S-1-5-19", true, "Local Service"},
+ {"network_service", "S-1-5-20", true, "Network Service"},
+ {"administrators_group", "S-1-5-32-544", true, "Administrators group"},
+ {"builtin_administrator", "S-1-5-500", true, "Built-in Administrator"},
+
+ // Domain accounts
+ {"domain_administrator", "S-1-5-21-1234567890-1234567890-1234567890-500", true, "Domain Administrator (RID 500)"},
+ {"domain_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-512", true, "Domain Admins group"},
+ {"domain_controllers_group", "S-1-5-21-1234567890-1234567890-1234567890-516", true, "Domain Controllers group"},
+ {"enterprise_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-519", true, "Enterprise Admins group"},
+
+ // Regular users
+ {"regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1001", false, "Regular domain user"},
+ {"another_regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1234", false, "Another regular user"},
+ {"local_user", "S-1-5-21-1234567890-1234567890-1234567890-1000", false, "Local regular user"},
+
+ // Groups that are not privileged
+ {"domain_users", "S-1-5-21-1234567890-1234567890-1234567890-513", false, "Domain Users group"},
+ {"power_users", "S-1-5-32-547", false, "Power Users group"},
+
+ // Invalid SIDs
+ {"malformed_sid", "S-1-5-invalid", false, "Malformed SID"},
+ {"empty_sid", "", false, "Empty SID"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := isWindowsPrivilegedSID(tt.sid)
+ assert.Equal(t, tt.privileged, result, "Failed for %s: %s", tt.description, tt.sid)
+ })
+ }
+}
+
+func TestIsSameUser(t *testing.T) {
+ tests := []struct {
+ name string
+ user1 string
+ user2 string
+ os string
+ expected bool
+ }{
+ // Basic cases
+ {"same_username", "alice", "alice", "linux", true},
+ {"different_username", "alice", "bob", "linux", false},
+
+ // Linux (no domain processing)
+ {"linux_domain_vs_bare", "DOMAIN\\alice", "alice", "linux", false},
+ {"linux_email_vs_bare", "alice@domain.com", "alice", "linux", false},
+ {"linux_same_literal", "DOMAIN\\alice", "DOMAIN\\alice", "linux", true},
+
+ // Windows (with domain processing) - Note: parameter order is (requested, current, os, expected)
+ {"windows_domain_vs_bare", "alice", "DOMAIN\\alice", "windows", true}, // bare username matches domain current user
+ {"windows_email_vs_bare", "alice", "alice@domain.com", "windows", true}, // bare username matches email current user
+ {"windows_different_domains_same_user", "DOMAIN1\\alice", "DOMAIN2\\alice", "windows", false}, // SECURITY: different domains = different users
+ {"windows_case_insensitive", "Alice", "alice", "windows", true},
+ {"windows_different_users", "DOMAIN\\alice", "DOMAIN\\bob", "windows", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Set up OS mock
+ cleanup := setupTestDependencies(nil, nil, tt.os, 1000, nil, nil)
+ defer cleanup()
+
+ result := isSameUser(tt.user1, tt.user2)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestUsernameValidation_Unix(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("Unix-specific username validation tests")
+ }
+
+ tests := []struct {
+ name string
+ username string
+ wantErr bool
+ errMsg string
+ }{
+ // Valid usernames (Unix/POSIX)
+ {"valid_alphanumeric", "user123", false, ""},
+ {"valid_with_dots", "user.name", false, ""},
+ {"valid_with_hyphens", "user-name", false, ""},
+ {"valid_with_underscores", "user_name", false, ""},
+ {"valid_uppercase", "UserName", false, ""},
+ {"valid_starting_with_digit", "123user", false, ""},
+ {"valid_starting_with_dot", ".hidden", false, ""},
+
+ // Invalid usernames (Unix/POSIX)
+ {"empty_username", "", true, "username cannot be empty"},
+ {"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
+ {"username_starting_with_hyphen", "-user", true, "invalid characters"}, // POSIX restriction
+ {"username_with_spaces", "user name", true, "invalid characters"},
+ {"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
+ {"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
+ {"username_with_pipe", "user|rm", true, "invalid characters"},
+ {"username_with_ampersand", "user&rm", true, "invalid characters"},
+ {"username_with_quotes", "user\"name", true, "invalid characters"},
+ {"username_with_newline", "user\nname", true, "invalid characters"},
+ {"reserved_dot", ".", true, "cannot be '.' or '..'"},
+ {"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
+ {"username_with_at_symbol", "user@domain", true, "invalid characters"}, // Not allowed in bare Unix usernames
+ {"username_with_backslash", "user\\name", true, "invalid characters"}, // Not allowed in Unix usernames
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateUsername(tt.username)
+ if tt.wantErr {
+ assert.Error(t, err, "Should reject invalid username")
+ if tt.errMsg != "" {
+ assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
+ }
+ } else {
+ assert.NoError(t, err, "Should accept valid username")
+ }
+ })
+ }
+}
+
+func TestUsernameValidation_Windows(t *testing.T) {
+ if runtime.GOOS != "windows" {
+ t.Skip("Windows-specific username validation tests")
+ }
+
+ tests := []struct {
+ name string
+ username string
+ wantErr bool
+ errMsg string
+ }{
+ // Valid usernames (Windows)
+ {"valid_alphanumeric", "user123", false, ""},
+ {"valid_with_dots", "user.name", false, ""},
+ {"valid_with_hyphens", "user-name", false, ""},
+ {"valid_with_underscores", "user_name", false, ""},
+ {"valid_uppercase", "UserName", false, ""},
+ {"valid_starting_with_digit", "123user", false, ""},
+ {"valid_starting_with_dot", ".hidden", false, ""},
+ {"valid_starting_with_hyphen", "-user", false, ""}, // Windows allows this
+ {"valid_domain_username", "DOMAIN\\user", false, ""}, // Windows domain format
+ {"valid_email_username", "user@domain.com", false, ""}, // Windows email format
+ {"valid_machine_username", "MACHINE\\user", false, ""}, // Windows machine format
+
+ // Invalid usernames (Windows)
+ {"empty_username", "", true, "username cannot be empty"},
+ {"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
+ {"username_with_spaces", "user name", true, "invalid characters"},
+ {"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
+ {"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
+ {"username_with_pipe", "user|rm", true, "invalid characters"},
+ {"username_with_ampersand", "user&rm", true, "invalid characters"},
+ {"username_with_quotes", "user\"name", true, "invalid characters"},
+ {"username_with_newline", "user\nname", true, "invalid characters"},
+ {"username_with_brackets", "user[name]", true, "invalid characters"},
+ {"username_with_colon", "user:name", true, "invalid characters"},
+ {"username_with_semicolon", "user;name", true, "invalid characters"},
+ {"username_with_equals", "user=name", true, "invalid characters"},
+ {"username_with_comma", "user,name", true, "invalid characters"},
+ {"username_with_plus", "user+name", true, "invalid characters"},
+ {"username_with_asterisk", "user*name", true, "invalid characters"},
+ {"username_with_question", "user?name", true, "invalid characters"},
+ {"username_with_angles", "user", true, "invalid characters"},
+ {"reserved_dot", ".", true, "cannot be '.' or '..'"},
+ {"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
+ {"username_ending_with_period", "user.", true, "cannot end with a period"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateUsername(tt.username)
+ if tt.wantErr {
+ assert.Error(t, err, "Should reject invalid username")
+ if tt.errMsg != "" {
+ assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
+ }
+ } else {
+ assert.NoError(t, err, "Should accept valid username")
+ }
+ })
+ }
+}
+
+// Test real-world integration scenarios with actual platform capabilities
+func TestCheckPrivileges_RealWorldScenarios(t *testing.T) {
+ tests := []struct {
+ name string
+ feature string
+ featureSupportsUserSwitch bool
+ requestedUsername string
+ allowRoot bool
+ expectedBehaviorPattern string
+ }{
+ {"SSH_login_current_user", "SSH login", true, "", true, "should_allow_current_user"},
+ {"SFTP_current_user", "SFTP", true, "", true, "should_allow_current_user"},
+ {"port_forwarding_current_user", "port forwarding", false, "", true, "should_allow_current_user"},
+ {"SSH_login_root_not_allowed", "SSH login", true, "root", false, "should_deny_root"},
+ {"port_forwarding_different_user", "port forwarding", false, "differentuser", true, "should_deny_switching"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Mock privileged environment to ensure consistent test behavior across environments
+ cleanup := setupTestDependencies(
+ createTestUser("root", "0", "0", "/root"), // Running as root
+ nil,
+ runtime.GOOS,
+ 0, // euid 0 (root)
+ map[string]*user.User{
+ "root": createTestUser("root", "0", "0", "/root"),
+ "differentuser": createTestUser("differentuser", "1000", "1000", "/home/differentuser"),
+ },
+ nil,
+ )
+ defer cleanup()
+
+ server := &Server{allowRootLogin: tt.allowRoot}
+
+ result := server.CheckPrivileges(PrivilegeCheckRequest{
+ RequestedUsername: tt.requestedUsername,
+ FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
+ FeatureName: tt.feature,
+ })
+
+ switch tt.expectedBehaviorPattern {
+ case "should_allow_current_user":
+ assert.True(t, result.Allowed, "Should allow current user access")
+ assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
+ case "should_deny_root":
+ assert.False(t, result.Allowed, "Should deny root when not allowed")
+ assert.Contains(t, result.Error.Error(), "root", "Should mention root in error")
+ case "should_deny_switching":
+ assert.False(t, result.Allowed, "Should deny when feature doesn't support switching")
+ assert.Contains(t, result.Error.Error(), "user switching not supported", "Should mention switching in error")
+ }
+ })
+ }
+}
+
+// Test with actual platform capabilities - no mocking
+func TestCheckPrivileges_ActualPlatform(t *testing.T) {
+ // This test uses the REAL platform capabilities
+ server := &Server{allowRootLogin: true}
+
+ // Test current user access - should always work
+ result := server.CheckPrivileges(PrivilegeCheckRequest{
+ RequestedUsername: "", // Current user
+ FeatureSupportsUserSwitch: true,
+ FeatureName: "SSH login",
+ })
+
+ assert.True(t, result.Allowed, "Current user should always be allowed")
+ assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
+ assert.NotNil(t, result.User, "Should return current user")
+
+ // Test user switching capability based on actual platform
+ actualIsPrivileged := isCurrentProcessPrivileged() // REAL check
+ actualOS := runtime.GOOS // REAL check
+
+ t.Logf("Platform capabilities: OS=%s, isPrivileged=%v, supportsUserSwitching=%v",
+ actualOS, actualIsPrivileged, actualIsPrivileged)
+
+ // Test requesting different user
+ result = server.CheckPrivileges(PrivilegeCheckRequest{
+ RequestedUsername: "nonexistentuser",
+ FeatureSupportsUserSwitch: true,
+ FeatureName: "SSH login",
+ })
+
+ switch {
+ case actualOS == "windows":
+ // Windows supports user switching but should fail on nonexistent user
+ assert.False(t, result.Allowed, "Windows should deny nonexistent user")
+ assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
+ assert.Contains(t, result.Error.Error(), "not found",
+ "Should indicate user not found")
+ case !actualIsPrivileged:
+ // Non-privileged Unix processes should fallback to current user
+ assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user")
+ assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens")
+ assert.True(t, result.UsedFallback, "Should indicate fallback was used")
+ assert.NotNil(t, result.User, "Should return current user")
+ default:
+ // Privileged Unix processes should attempt user lookup
+ assert.False(t, result.Allowed, "Should fail due to nonexistent user")
+ assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
+ assert.Contains(t, result.Error.Error(), "nonexistentuser",
+ "Should indicate user not found")
+ }
+}
+
+// Test platform detection logic with dependency injection
+func TestPlatformLogic_DependencyInjection(t *testing.T) {
+ tests := []struct {
+ name string
+ os string
+ euid int
+ currentUser *user.User
+ expectedIsProcessPrivileged bool
+ expectedSupportsUserSwitching bool
+ }{
+ {
+ name: "linux_root_process",
+ os: "linux",
+ euid: 0,
+ currentUser: createTestUser("root", "0", "0", "/root"),
+ expectedIsProcessPrivileged: true,
+ expectedSupportsUserSwitching: true,
+ },
+ {
+ name: "linux_non_root_process",
+ os: "linux",
+ euid: 1000,
+ currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
+ expectedIsProcessPrivileged: false,
+ expectedSupportsUserSwitching: false,
+ },
+ {
+ name: "windows_admin_process",
+ os: "windows",
+ euid: 1000, // euid ignored on Windows
+ currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
+ expectedIsProcessPrivileged: true,
+ expectedSupportsUserSwitching: true, // Windows supports user switching when privileged
+ },
+ {
+ name: "windows_regular_process",
+ os: "windows",
+ euid: 1000, // euid ignored on Windows
+ currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
+ expectedIsProcessPrivileged: false,
+ expectedSupportsUserSwitching: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Inject platform dependencies and test REAL logic
+ cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, nil, nil)
+ defer cleanup()
+
+ // Test the actual functions with injected dependencies
+ actualIsPrivileged := isCurrentProcessPrivileged()
+ actualSupportsUserSwitching := actualIsPrivileged
+
+ assert.Equal(t, tt.expectedIsProcessPrivileged, actualIsPrivileged,
+ "isCurrentProcessPrivileged() result mismatch")
+ assert.Equal(t, tt.expectedSupportsUserSwitching, actualSupportsUserSwitching,
+ "supportsUserSwitching() result mismatch")
+
+ t.Logf("Platform: %s, EUID: %d, User: %s", tt.os, tt.euid, tt.currentUser.Username)
+ t.Logf("Results: isPrivileged=%v, supportsUserSwitching=%v",
+ actualIsPrivileged, actualSupportsUserSwitching)
+ })
+ }
+}
+
+func TestCheckPrivileges_WindowsElevatedUserSwitching(t *testing.T) {
+ // Test Windows elevated user switching scenarios with simplified privilege logic
+ tests := []struct {
+ name string
+ currentUser *user.User
+ requestedUsername string
+ allowRoot bool
+ expectedAllowed bool
+ expectedErrorContains string
+ }{
+ {
+ name: "windows_admin_can_switch_to_alice",
+ currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
+ requestedUsername: "alice",
+ allowRoot: true,
+ expectedAllowed: true,
+ },
+ {
+ name: "windows_non_admin_can_try_switch",
+ currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\\\Users\\\\alice"),
+ requestedUsername: "bob",
+ allowRoot: true,
+ expectedAllowed: true, // Privilege check allows it, OS will reject during execution
+ },
+ {
+ name: "windows_system_can_switch_to_alice",
+ currentUser: createTestUser("SYSTEM", "S-1-5-18", "S-1-5-18", "C:\\\\Windows\\\\system32\\\\config\\\\systemprofile"),
+ requestedUsername: "alice",
+ allowRoot: true,
+ expectedAllowed: true,
+ },
+ {
+ name: "windows_admin_root_not_allowed",
+ currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
+ requestedUsername: "root",
+ allowRoot: false,
+ expectedAllowed: false,
+ expectedErrorContains: "privileged user login is disabled",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Setup test dependencies with Windows OS and specified privileges
+ lookupUsers := map[string]*user.User{
+ tt.requestedUsername: createTestUser(tt.requestedUsername, "1002", "1002", "C:\\\\Users\\\\"+tt.requestedUsername),
+ }
+ cleanup := setupTestDependencies(tt.currentUser, nil, "windows", 1000, lookupUsers, nil)
+ defer cleanup()
+
+ server := &Server{allowRootLogin: tt.allowRoot}
+
+ result := server.CheckPrivileges(PrivilegeCheckRequest{
+ RequestedUsername: tt.requestedUsername,
+ FeatureSupportsUserSwitch: true,
+ FeatureName: "SSH login",
+ })
+
+ assert.Equal(t, tt.expectedAllowed, result.Allowed,
+ "Privilege check result should match expected for %s", tt.name)
+
+ if !tt.expectedAllowed && tt.expectedErrorContains != "" {
+ assert.NotNil(t, result.Error, "Should have error when not allowed")
+ assert.Contains(t, result.Error.Error(), tt.expectedErrorContains,
+ "Error should contain expected message")
+ }
+
+ if tt.expectedAllowed && tt.requestedUsername != "" && tt.currentUser.Username != tt.requestedUsername {
+ assert.True(t, result.RequiresUserSwitching, "Should require user switching for different user")
+ }
+ })
+ }
+}
diff --git a/client/ssh/server/userswitching_js.go b/client/ssh/server/userswitching_js.go
new file mode 100644
index 000000000..333c19259
--- /dev/null
+++ b/client/ssh/server/userswitching_js.go
@@ -0,0 +1,8 @@
+//go:build js
+
+package server
+
+// enableUserSwitching is not supported on JS/WASM
+func enableUserSwitching() error {
+ return errNotSupported
+}
diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go
new file mode 100644
index 000000000..bc1557419
--- /dev/null
+++ b/client/ssh/server/userswitching_unix.go
@@ -0,0 +1,260 @@
+//go:build unix
+
+package server
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/netip"
+ "os"
+ "os/exec"
+ "os/user"
+ "regexp"
+ "runtime"
+ "strconv"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+)
+
+// POSIX portable filename character set regex: [a-zA-Z0-9._-]
+// First character cannot be hyphen (POSIX requirement)
+var posixUsernameRegex = regexp.MustCompile(`^[a-zA-Z0-9._][a-zA-Z0-9._-]*$`)
+
+// validateUsername validates that a username conforms to POSIX standards with security considerations
+func validateUsername(username string) error {
+ if username == "" {
+ return errors.New("username cannot be empty")
+ }
+
+ // POSIX allows up to 256 characters, but practical limit is 32 for compatibility
+ if len(username) > 32 {
+ return errors.New("username too long (max 32 characters)")
+ }
+
+ if !posixUsernameRegex.MatchString(username) {
+ return errors.New("username contains invalid characters (must match POSIX portable filename character set)")
+ }
+
+ if username == "." || username == ".." {
+ return fmt.Errorf("username cannot be '.' or '..'")
+ }
+
+ // Warn if username is fully numeric (can cause issues with UID/username ambiguity)
+ if isFullyNumeric(username) {
+ log.Warnf("fully numeric username '%s' may cause issues with some commands", username)
+ }
+
+ return nil
+}
+
+// isFullyNumeric checks if username contains only digits
+func isFullyNumeric(username string) bool {
+ for _, char := range username {
+ if char < '0' || char > '9' {
+ return false
+ }
+ }
+ return true
+}
+
+// createPtyLoginCommand creates a Pty command using login for privileged processes
+func (s *Server) createPtyLoginCommand(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
+ loginPath, args, err := s.getLoginCmd(localUser.Username, session.RemoteAddr())
+ if err != nil {
+ return nil, fmt.Errorf("get login command: %w", err)
+ }
+
+ execCmd := exec.CommandContext(session.Context(), loginPath, args...)
+ execCmd.Dir = localUser.HomeDir
+ execCmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
+
+ return execCmd, nil
+}
+
+// getLoginCmd returns the login command and args for privileged Pty user switching
+func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []string, error) {
+ loginPath, err := exec.LookPath("login")
+ if err != nil {
+ return "", nil, fmt.Errorf("login command not available: %w", err)
+ }
+
+ addrPort, err := netip.ParseAddrPort(remoteAddr.String())
+ if err != nil {
+ return "", nil, fmt.Errorf("parse remote address: %w", err)
+ }
+
+ switch runtime.GOOS {
+ case "linux":
+ p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String())
+ return p, a, nil
+ case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
+ return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
+ default:
+ return "", nil, fmt.Errorf("unsupported Unix platform for login command: %s", runtime.GOOS)
+ }
+}
+
+// getLinuxLoginCmd returns the login command for Linux systems.
+// Handles differences between util-linux and shadow-utils login implementations.
+func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) {
+ // Special handling for Arch Linux without /etc/pam.d/remote
+ var loginArgs []string
+ if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
+ loginArgs = []string{"-f", username, "-p"}
+ } else {
+ loginArgs = []string{"-f", username, "-h", remoteIP, "-p"}
+ }
+
+ // util-linux login requires setsid -c to create a new session and set the
+ // controlling terminal. Without this, vhangup() kills the parent process.
+ // See https://bugs.debian.org/1078023 for details.
+ // TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec()
+ // to avoid external setsid dependency.
+ if !s.loginIsUtilLinux {
+ return loginPath, loginArgs
+ }
+
+ setsidPath, err := exec.LookPath("setsid")
+ if err != nil {
+ log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err)
+ return loginPath, loginArgs
+ }
+
+ args := append([]string{"-w", "-c", loginPath}, loginArgs...)
+ return setsidPath, args
+}
+
+// fileExists checks if a file exists
+func (s *Server) fileExists(path string) bool {
+ _, err := os.Stat(path)
+ return err == nil
+}
+
+// parseUserCredentials extracts numeric UID, GID, and supplementary groups
+func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []uint32, error) {
+ uid64, err := strconv.ParseUint(localUser.Uid, 10, 32)
+ if err != nil {
+ return 0, 0, nil, fmt.Errorf("invalid UID %s: %w", localUser.Uid, err)
+ }
+ uid := uint32(uid64)
+
+ gid64, err := strconv.ParseUint(localUser.Gid, 10, 32)
+ if err != nil {
+ return 0, 0, nil, fmt.Errorf("invalid GID %s: %w", localUser.Gid, err)
+ }
+ gid := uint32(gid64)
+
+ groups, err := s.getSupplementaryGroups(localUser.Username)
+ if err != nil {
+ log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err)
+ groups = []uint32{gid}
+ }
+
+ return uid, gid, groups, nil
+}
+
+// getSupplementaryGroups retrieves supplementary group IDs for a user
+func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
+ u, err := user.Lookup(username)
+ if err != nil {
+ return nil, fmt.Errorf("lookup user %s: %w", username, err)
+ }
+
+ groupIDStrings, err := u.GroupIds()
+ if err != nil {
+ return nil, fmt.Errorf("get group IDs for user %s: %w", username, err)
+ }
+
+ groups := make([]uint32, len(groupIDStrings))
+ for i, gidStr := range groupIDStrings {
+ gid64, err := strconv.ParseUint(gidStr, 10, 32)
+ if err != nil {
+ return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err)
+ }
+ groups[i] = uint32(gid64)
+ }
+
+ return groups, nil
+}
+
+// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
+// Returns the command and a cleanup function (no-op on Unix).
+func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
+ log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
+
+ if err := validateUsername(localUser.Username); err != nil {
+ return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
+ }
+
+ uid, gid, groups, err := s.parseUserCredentials(localUser)
+ if err != nil {
+ return nil, nil, fmt.Errorf("parse user credentials: %w", err)
+ }
+ privilegeDropper := NewPrivilegeDropper()
+ config := ExecutorConfig{
+ UID: uid,
+ GID: gid,
+ Groups: groups,
+ WorkingDir: localUser.HomeDir,
+ Shell: getUserShell(localUser.Uid),
+ Command: session.RawCommand(),
+ PTY: hasPty,
+ }
+
+ cmd, err := privilegeDropper.CreateExecutorCommand(session.Context(), config)
+ return cmd, func() {}, err
+}
+
+// enableUserSwitching is a no-op on Unix systems
+func enableUserSwitching() error {
+ return nil
+}
+
+// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results
+func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
+ localUser := privilegeResult.User
+ if localUser == nil {
+ return nil, errors.New("no user in privilege result")
+ }
+
+ if privilegeResult.UsedFallback {
+ return s.createDirectPtyCommand(session, localUser, ptyReq), nil
+ }
+
+ return s.createPtyLoginCommand(localUser, ptyReq, session)
+}
+
+// createDirectPtyCommand creates a direct Pty command without privilege dropping
+func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
+ log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
+
+ shell := getUserShell(localUser.Uid)
+ args := s.getShellCommandArgs(shell, session.RawCommand())
+
+ cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
+ cmd.Dir = localUser.HomeDir
+ cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
+
+ return cmd
+}
+
+// preparePtyEnv prepares environment variables for Pty execution
+func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) []string {
+ termType := ptyReq.Term
+ if termType == "" {
+ termType = "xterm-256color"
+ }
+
+ env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
+ env = append(env, prepareSSHEnv(session)...)
+ env = append(env, fmt.Sprintf("TERM=%s", termType))
+
+ for _, v := range session.Environ() {
+ if acceptEnv(v) {
+ env = append(env, v)
+ }
+ }
+ return env
+}
diff --git a/client/ssh/server/userswitching_windows.go b/client/ssh/server/userswitching_windows.go
new file mode 100644
index 000000000..5a5f75fa4
--- /dev/null
+++ b/client/ssh/server/userswitching_windows.go
@@ -0,0 +1,274 @@
+//go:build windows
+
+package server
+
+import (
+ "errors"
+ "fmt"
+ "os/exec"
+ "os/user"
+ "strings"
+ "unsafe"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+)
+
+// validateUsername validates Windows usernames according to SAM Account Name rules
+func validateUsername(username string) error {
+ if username == "" {
+ return fmt.Errorf("username cannot be empty")
+ }
+
+ usernameToValidate := extractUsernameFromDomain(username)
+
+ if err := validateUsernameLength(usernameToValidate); err != nil {
+ return err
+ }
+
+ if err := validateUsernameCharacters(usernameToValidate); err != nil {
+ return err
+ }
+
+ if err := validateUsernameFormat(usernameToValidate); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// extractUsernameFromDomain extracts the username part from domain\username or username@domain format
+func extractUsernameFromDomain(username string) string {
+ if idx := strings.LastIndex(username, `\`); idx != -1 {
+ return username[idx+1:]
+ }
+ if idx := strings.Index(username, "@"); idx != -1 {
+ return username[:idx]
+ }
+ return username
+}
+
+// validateUsernameLength checks if username length is within Windows limits
+func validateUsernameLength(username string) error {
+ if len(username) > 20 {
+ return fmt.Errorf("username too long (max 20 characters for Windows)")
+ }
+ return nil
+}
+
+// validateUsernameCharacters checks for invalid characters in Windows usernames
+func validateUsernameCharacters(username string) error {
+ invalidChars := []rune{'"', '/', '[', ']', ':', ';', '|', '=', ',', '+', '*', '?', '<', '>', ' ', '`', '&', '\n'}
+ for _, char := range username {
+ for _, invalid := range invalidChars {
+ if char == invalid {
+ return fmt.Errorf("username contains invalid characters")
+ }
+ }
+ if char < 32 || char == 127 {
+ return fmt.Errorf("username contains control characters")
+ }
+ }
+ return nil
+}
+
+// validateUsernameFormat checks for invalid username formats and patterns
+func validateUsernameFormat(username string) error {
+ if username == "." || username == ".." {
+ return fmt.Errorf("username cannot be '.' or '..'")
+ }
+
+ if strings.HasSuffix(username, ".") {
+ return fmt.Errorf("username cannot end with a period")
+ }
+
+ return nil
+}
+
+// createExecutorCommand creates a command using Windows executor for privilege dropping.
+// Returns the command and a cleanup function that must be called after starting the process.
+func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
+ log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
+
+ username, _ := s.parseUsername(localUser.Username)
+ if err := validateUsername(username); err != nil {
+ return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
+ }
+
+ return s.createUserSwitchCommand(localUser, session, hasPty)
+}
+
+// createUserSwitchCommand creates a command with Windows user switching.
+// Returns the command and a cleanup function that must be called after starting the process.
+func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
+ username, domain := s.parseUsername(localUser.Username)
+
+ shell := getUserShell(localUser.Uid)
+
+ rawCmd := session.RawCommand()
+ var command string
+ if rawCmd != "" {
+ command = rawCmd
+ }
+
+ config := WindowsExecutorConfig{
+ Username: username,
+ Domain: domain,
+ WorkingDir: localUser.HomeDir,
+ Shell: shell,
+ Command: command,
+ Interactive: interactive || (rawCmd == ""),
+ }
+
+ dropper := NewPrivilegeDropper()
+ cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ cleanup := func() {
+ if token != 0 {
+ if err := windows.CloseHandle(windows.Handle(token)); err != nil {
+ log.Debugf("close primary token: %v", err)
+ }
+ }
+ }
+
+ return cmd, cleanup, nil
+}
+
+// parseUsername extracts username and domain from a Windows username
+func (s *Server) parseUsername(fullUsername string) (username, domain string) {
+ // Handle DOMAIN\username format
+ if idx := strings.LastIndex(fullUsername, `\`); idx != -1 {
+ domain = fullUsername[:idx]
+ username = fullUsername[idx+1:]
+ return username, domain
+ }
+
+ // Handle username@domain format
+ if username, domain, ok := strings.Cut(fullUsername, "@"); ok {
+ return username, domain
+ }
+
+ // Local user (no domain)
+ return fullUsername, "."
+}
+
+// hasPrivilege checks if the current process has a specific privilege
+func hasPrivilege(token windows.Handle, privilegeName string) (bool, error) {
+ var luid windows.LUID
+ if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
+ return false, fmt.Errorf("lookup privilege value: %w", err)
+ }
+
+ var returnLength uint32
+ err := windows.GetTokenInformation(
+ windows.Token(token),
+ windows.TokenPrivileges,
+ nil, // null buffer to get size
+ 0,
+ &returnLength,
+ )
+
+ if err != nil && !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
+ return false, fmt.Errorf("get token information size: %w", err)
+ }
+
+ buffer := make([]byte, returnLength)
+ err = windows.GetTokenInformation(
+ windows.Token(token),
+ windows.TokenPrivileges,
+ &buffer[0],
+ returnLength,
+ &returnLength,
+ )
+ if err != nil {
+ return false, fmt.Errorf("get token information: %w", err)
+ }
+
+ privileges := (*windows.Tokenprivileges)(unsafe.Pointer(&buffer[0]))
+
+ // Check if the privilege is present and enabled
+ for i := uint32(0); i < privileges.PrivilegeCount; i++ {
+ privilege := (*windows.LUIDAndAttributes)(unsafe.Pointer(
+ uintptr(unsafe.Pointer(&privileges.Privileges[0])) +
+ uintptr(i)*unsafe.Sizeof(windows.LUIDAndAttributes{}),
+ ))
+ if privilege.Luid == luid {
+ return (privilege.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0, nil
+ }
+ }
+
+ return false, nil
+}
+
+// enablePrivilege enables a specific privilege for the current process token
+// This is required because privileges like SeAssignPrimaryTokenPrivilege are present
+// but disabled by default, even for the SYSTEM account
+func enablePrivilege(token windows.Handle, privilegeName string) error {
+ var luid windows.LUID
+ if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
+ return fmt.Errorf("lookup privilege value for %s: %w", privilegeName, err)
+ }
+
+ privileges := windows.Tokenprivileges{
+ PrivilegeCount: 1,
+ Privileges: [1]windows.LUIDAndAttributes{
+ {
+ Luid: luid,
+ Attributes: windows.SE_PRIVILEGE_ENABLED,
+ },
+ },
+ }
+
+ err := windows.AdjustTokenPrivileges(
+ windows.Token(token),
+ false,
+ &privileges,
+ 0,
+ nil,
+ nil,
+ )
+ if err != nil {
+ return fmt.Errorf("adjust token privileges for %s: %w", privilegeName, err)
+ }
+
+ hasPriv, err := hasPrivilege(token, privilegeName)
+ if err != nil {
+ return fmt.Errorf("verify privilege %s after enabling: %w", privilegeName, err)
+ }
+ if !hasPriv {
+ return fmt.Errorf("privilege %s could not be enabled (may not be granted to account)", privilegeName)
+ }
+
+ log.Debugf("Successfully enabled privilege %s for current process", privilegeName)
+ return nil
+}
+
+// enableUserSwitching enables required privileges for Windows user switching
+func enableUserSwitching() error {
+ process := windows.CurrentProcess()
+
+ var token windows.Token
+ err := windows.OpenProcessToken(
+ process,
+ windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY,
+ &token,
+ )
+ if err != nil {
+ return fmt.Errorf("open process token: %w", err)
+ }
+ defer func() {
+ if err := windows.CloseHandle(windows.Handle(token)); err != nil {
+ log.Debugf("Failed to close process token: %v", err)
+ }
+ }()
+
+ if err := enablePrivilege(windows.Handle(token), "SeAssignPrimaryTokenPrivilege"); err != nil {
+ return fmt.Errorf("enable SeAssignPrimaryTokenPrivilege: %w", err)
+ }
+ log.Infof("Windows user switching privileges enabled successfully")
+ return nil
+}
diff --git a/client/ssh/server/winpty/conpty.go b/client/ssh/server/winpty/conpty.go
new file mode 100644
index 000000000..0f3659ffe
--- /dev/null
+++ b/client/ssh/server/winpty/conpty.go
@@ -0,0 +1,487 @@
+//go:build windows
+
+package winpty
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "strings"
+ "sync"
+ "syscall"
+ "unsafe"
+
+ "github.com/gliderlabs/ssh"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+)
+
+var (
+ ErrEmptyEnvironment = errors.New("empty environment")
+)
+
+const (
+ extendedStartupInfoPresent = 0x00080000
+ createUnicodeEnvironment = 0x00000400
+ procThreadAttributePseudoConsole = 0x00020016
+
+ PowerShellCommandFlag = "-Command"
+
+ errCloseInputRead = "close input read handle: %v"
+ errCloseConPtyCleanup = "close ConPty handle during cleanup"
+)
+
+// PtyConfig holds configuration for Pty execution.
+type PtyConfig struct {
+ Shell string
+ Command string
+ Width int
+ Height int
+ WorkingDir string
+}
+
+// UserConfig holds user execution configuration.
+type UserConfig struct {
+ Token windows.Handle
+ Environment []string
+}
+
+var (
+ kernel32 = windows.NewLazySystemDLL("kernel32.dll")
+ procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole")
+ procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList")
+ procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute")
+ procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList")
+)
+
+// ExecutePtyWithUserToken executes a command with ConPty using user token.
+func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
+ args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
+ commandLine := buildCommandLine(args)
+
+ config := ExecutionConfig{
+ Pty: ptyConfig,
+ User: userConfig,
+ Session: session,
+ Context: ctx,
+ }
+
+ return executeConPtyWithConfig(commandLine, config)
+}
+
+// ExecutionConfig holds all execution configuration.
+type ExecutionConfig struct {
+ Pty PtyConfig
+ User UserConfig
+ Session ssh.Session
+ Context context.Context
+}
+
+// executeConPtyWithConfig creates ConPty and executes process with configuration.
+func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error {
+ ctx := config.Context
+ session := config.Session
+ width := config.Pty.Width
+ height := config.Pty.Height
+ userToken := config.User.Token
+ userEnv := config.User.Environment
+ workingDir := config.Pty.WorkingDir
+
+ inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
+ if err != nil {
+ return fmt.Errorf("create ConPty pipes: %w", err)
+ }
+
+ hPty, err := createConPty(width, height, inputRead, outputWrite)
+ if err != nil {
+ return fmt.Errorf("create ConPty: %w", err)
+ }
+
+ primaryToken, err := duplicateToPrimaryToken(userToken)
+ if err != nil {
+ if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
+ log.Debugf(errCloseConPtyCleanup)
+ }
+ closeHandles(inputRead, inputWrite, outputRead, outputWrite)
+ return fmt.Errorf("duplicate to primary token: %w", err)
+ }
+ defer func() {
+ if err := windows.CloseHandle(primaryToken); err != nil {
+ log.Debugf("close primary token: %v", err)
+ }
+ }()
+
+ siEx, err := setupConPtyStartupInfo(hPty)
+ if err != nil {
+ if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
+ log.Debugf(errCloseConPtyCleanup)
+ }
+ closeHandles(inputRead, inputWrite, outputRead, outputWrite)
+ return fmt.Errorf("setup startup info: %w", err)
+ }
+ defer func() {
+ _, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)))
+ }()
+
+ pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx)
+ if err != nil {
+ if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
+ log.Debugf(errCloseConPtyCleanup)
+ }
+ closeHandles(inputRead, inputWrite, outputRead, outputWrite)
+ return fmt.Errorf("create process as user with ConPty: %w", err)
+ }
+ defer closeProcessInfo(pi)
+
+ if err := windows.CloseHandle(inputRead); err != nil {
+ log.Debugf(errCloseInputRead, err)
+ }
+ if err := windows.CloseHandle(outputWrite); err != nil {
+ log.Debugf("close output write handle: %v", err)
+ }
+
+ return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, session, pi.Process)
+}
+
+// createConPtyPipes creates input/output pipes for ConPty.
+func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) {
+ if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil {
+ return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err)
+ }
+
+ if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil {
+ if closeErr := windows.CloseHandle(inputRead); closeErr != nil {
+ log.Debugf(errCloseInputRead, closeErr)
+ }
+ if closeErr := windows.CloseHandle(inputWrite); closeErr != nil {
+ log.Debugf("close input write handle: %v", closeErr)
+ }
+ return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err)
+ }
+
+ return inputRead, inputWrite, outputRead, outputWrite, nil
+}
+
+// createConPty creates a Windows ConPty with the specified size and pipe handles.
+func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) {
+ size := windows.Coord{X: int16(width), Y: int16(height)}
+
+ var hPty windows.Handle
+ if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil {
+ return 0, fmt.Errorf("CreatePseudoConsole: %w", err)
+ }
+
+ return hPty, nil
+}
+
+// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes.
+func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) {
+ var siEx windows.StartupInfoEx
+ siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx))
+
+ var attrListSize uintptr
+ ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize)))
+ if ret == 0 && attrListSize == 0 {
+ return nil, fmt.Errorf("get attribute list size")
+ }
+
+ attrListBytes := make([]byte, attrListSize)
+ siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0]))
+
+ ret, _, err := procInitializeProcThreadAttributeList.Call(
+ uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
+ 1,
+ 0,
+ uintptr(unsafe.Pointer(&attrListSize)),
+ )
+ if ret == 0 {
+ return nil, fmt.Errorf("initialize attribute list: %w", err)
+ }
+
+ ret, _, err = procUpdateProcThreadAttribute.Call(
+ uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
+ 0,
+ procThreadAttributePseudoConsole,
+ uintptr(hPty),
+ unsafe.Sizeof(hPty),
+ 0,
+ 0,
+ )
+ if ret == 0 {
+ return nil, fmt.Errorf("update thread attribute: %w", err)
+ }
+
+ return &siEx, nil
+}
+
+// createConPtyProcess creates the actual process with ConPty.
+func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) {
+ var pi windows.ProcessInformation
+ creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment)
+
+ commandLinePtr, err := windows.UTF16PtrFromString(commandLine)
+ if err != nil {
+ return nil, fmt.Errorf("convert command line to UTF16: %w", err)
+ }
+
+ envPtr, err := convertEnvironmentToUTF16(userEnv)
+ if err != nil {
+ return nil, err
+ }
+
+ var workingDirPtr *uint16
+ if workingDir != "" {
+ workingDirPtr, err = windows.UTF16PtrFromString(workingDir)
+ if err != nil {
+ return nil, fmt.Errorf("convert working directory to UTF16: %w", err)
+ }
+ }
+
+ siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES
+ siEx.StartupInfo.StdInput = windows.Handle(0)
+ siEx.StartupInfo.StdOutput = windows.Handle(0)
+ siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput
+
+ if userToken != windows.InvalidHandle {
+ err = windows.CreateProcessAsUser(
+ windows.Token(userToken),
+ nil,
+ commandLinePtr,
+ nil,
+ nil,
+ true,
+ creationFlags,
+ envPtr,
+ workingDirPtr,
+ &siEx.StartupInfo,
+ &pi,
+ )
+ } else {
+ err = windows.CreateProcess(
+ nil,
+ commandLinePtr,
+ nil,
+ nil,
+ true,
+ creationFlags,
+ envPtr,
+ workingDirPtr,
+ &siEx.StartupInfo,
+ &pi,
+ )
+ }
+
+ if err != nil {
+ return nil, fmt.Errorf("create process: %w", err)
+ }
+
+ return &pi, nil
+}
+
+// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format.
+func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) {
+ if len(userEnv) == 0 {
+ // Return nil pointer for empty environment - Windows API will inherit parent environment
+ return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
+ }
+
+ var envUTF16 []uint16
+ for _, envVar := range userEnv {
+ if envVar != "" {
+ utf16Str, err := windows.UTF16FromString(envVar)
+ if err != nil {
+ log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err)
+ continue
+ }
+ envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...)
+ envUTF16 = append(envUTF16, 0)
+ }
+ }
+ envUTF16 = append(envUTF16, 0)
+
+ if len(envUTF16) > 0 {
+ return &envUTF16[0], nil
+ }
+ // Return nil pointer when no valid environment variables found
+ return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
+}
+
+// duplicateToPrimaryToken converts an impersonation token to a primary token.
+func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) {
+ var primaryToken windows.Handle
+ if err := windows.DuplicateTokenEx(
+ windows.Token(token),
+ windows.TOKEN_ALL_ACCESS,
+ nil,
+ windows.SecurityImpersonation,
+ windows.TokenPrimary,
+ (*windows.Token)(&primaryToken),
+ ); err != nil {
+ return 0, fmt.Errorf("duplicate token: %w", err)
+ }
+ return primaryToken, nil
+}
+
+// SessionExiter provides the Exit method for reporting process exit status.
+type SessionExiter interface {
+ Exit(code int) error
+}
+
+// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
+func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, session SessionExiter, process windows.Handle) error {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+
+ var wg sync.WaitGroup
+ startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer)
+
+ processErr := waitForProcess(ctx, process)
+ if processErr != nil {
+ return processErr
+ }
+
+ var exitCode uint32
+ if err := windows.GetExitCodeProcess(process, &exitCode); err != nil {
+ log.Debugf("get exit code: %v", err)
+ } else {
+ if err := session.Exit(int(exitCode)); err != nil {
+ log.Debugf("report exit code: %v", err)
+ }
+ }
+
+ // Clean up in the original order after process completes
+ if err := reader.Close(); err != nil {
+ log.Debugf("close reader: %v", err)
+ }
+
+ ret, _, err := procClosePseudoConsole.Call(uintptr(hPty))
+ if ret == 0 {
+ log.Debugf("close ConPty handle: %v", err)
+ }
+
+ wg.Wait()
+
+ if err := windows.CloseHandle(outputRead); err != nil {
+ log.Debugf("close output read handle: %v", err)
+ }
+
+ return nil
+}
+
+// startIOBridging starts the I/O bridging goroutines.
+func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) {
+ wg.Add(2)
+
+ // Input: reader (SSH session) -> inputWrite (ConPty)
+ go func() {
+ defer wg.Done()
+ defer func() {
+ if err := windows.CloseHandle(inputWrite); err != nil {
+ log.Debugf("close input write handle in goroutine: %v", err)
+ }
+ }()
+
+ if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil {
+ log.Debugf("input copy ended with error: %v", err)
+ }
+ }()
+
+ // Output: outputRead (ConPty) -> writer (SSH session)
+ go func() {
+ defer wg.Done()
+ if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil {
+ log.Debugf("output copy ended with error: %v", err)
+ }
+ }()
+}
+
+// waitForProcess waits for process completion with context cancellation.
+func waitForProcess(ctx context.Context, process windows.Handle) error {
+ if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil {
+ return fmt.Errorf("wait for process %d: %w", process, err)
+ }
+ return nil
+}
+
+// buildShellArgs builds shell arguments for ConPty execution.
+func buildShellArgs(shell, command string) []string {
+ if command != "" {
+ return []string{shell, PowerShellCommandFlag, command}
+ }
+ return []string{shell}
+}
+
+// buildCommandLine builds a Windows command line from arguments using proper escaping.
+func buildCommandLine(args []string) string {
+ if len(args) == 0 {
+ return ""
+ }
+
+ var result strings.Builder
+ for i, arg := range args {
+ if i > 0 {
+ result.WriteString(" ")
+ }
+ result.WriteString(syscall.EscapeArg(arg))
+ }
+ return result.String()
+}
+
+// closeHandles closes multiple Windows handles.
+func closeHandles(handles ...windows.Handle) {
+ for _, handle := range handles {
+ if handle != windows.InvalidHandle {
+ if err := windows.CloseHandle(handle); err != nil {
+ log.Debugf("close handle: %v", err)
+ }
+ }
+ }
+}
+
+// closeProcessInfo closes process and thread handles.
+func closeProcessInfo(pi *windows.ProcessInformation) {
+ if pi != nil {
+ if err := windows.CloseHandle(pi.Process); err != nil {
+ log.Debugf("close process handle: %v", err)
+ }
+ if err := windows.CloseHandle(pi.Thread); err != nil {
+ log.Debugf("close thread handle: %v", err)
+ }
+ }
+}
+
+// windowsHandleReader wraps a Windows handle for reading.
+type windowsHandleReader struct {
+ handle windows.Handle
+}
+
+func (r *windowsHandleReader) Read(p []byte) (n int, err error) {
+ var bytesRead uint32
+ if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil {
+ return 0, err
+ }
+ return int(bytesRead), nil
+}
+
+func (r *windowsHandleReader) Close() error {
+ return windows.CloseHandle(r.handle)
+}
+
+// windowsHandleWriter wraps a Windows handle for writing.
+type windowsHandleWriter struct {
+ handle windows.Handle
+}
+
+func (w *windowsHandleWriter) Write(p []byte) (n int, err error) {
+ var bytesWritten uint32
+ if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil {
+ return 0, err
+ }
+ return int(bytesWritten), nil
+}
+
+func (w *windowsHandleWriter) Close() error {
+ return windows.CloseHandle(w.handle)
+}
diff --git a/client/ssh/server/winpty/conpty_test.go b/client/ssh/server/winpty/conpty_test.go
new file mode 100644
index 000000000..4f04e1fad
--- /dev/null
+++ b/client/ssh/server/winpty/conpty_test.go
@@ -0,0 +1,290 @@
+//go:build windows
+
+package winpty
+
+import (
+ "testing"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/sys/windows"
+)
+
+func TestBuildShellArgs(t *testing.T) {
+ tests := []struct {
+ name string
+ shell string
+ command string
+ expected []string
+ }{
+ {
+ name: "Shell with command",
+ shell: "powershell.exe",
+ command: "Get-Process",
+ expected: []string{"powershell.exe", "-Command", "Get-Process"},
+ },
+ {
+ name: "CMD with command",
+ shell: "cmd.exe",
+ command: "dir",
+ expected: []string{"cmd.exe", "-Command", "dir"},
+ },
+ {
+ name: "Shell interactive",
+ shell: "powershell.exe",
+ command: "",
+ expected: []string{"powershell.exe"},
+ },
+ {
+ name: "CMD interactive",
+ shell: "cmd.exe",
+ command: "",
+ expected: []string{"cmd.exe"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := buildShellArgs(tt.shell, tt.command)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestBuildCommandLine(t *testing.T) {
+ tests := []struct {
+ name string
+ args []string
+ expected string
+ }{
+ {
+ name: "Simple args",
+ args: []string{"cmd.exe", "/c", "echo"},
+ expected: "cmd.exe /c echo",
+ },
+ {
+ name: "Args with spaces",
+ args: []string{"Program Files\\app.exe", "arg with spaces"},
+ expected: `"Program Files\app.exe" "arg with spaces"`,
+ },
+ {
+ name: "Args with quotes",
+ args: []string{"cmd.exe", "/c", `echo "hello world"`},
+ expected: `cmd.exe /c "echo \"hello world\""`,
+ },
+ {
+ name: "PowerShell calling PowerShell",
+ args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`},
+ expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`,
+ },
+ {
+ name: "Complex nested quotes",
+ args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`},
+ expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`,
+ },
+ {
+ name: "Path with spaces and args",
+ args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`},
+ expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`,
+ },
+ {
+ name: "Empty argument",
+ args: []string{"cmd.exe", "/c", "echo", ""},
+ expected: `cmd.exe /c echo ""`,
+ },
+ {
+ name: "Argument with backslashes",
+ args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"},
+ expected: `robocopy C:\Source\ C:\Dest\ /E`,
+ },
+ {
+ name: "Empty args",
+ args: []string{},
+ expected: "",
+ },
+ {
+ name: "Single arg with space",
+ args: []string{"path with spaces"},
+ expected: `"path with spaces"`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := buildCommandLine(tt.args)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestCreateConPtyPipes(t *testing.T) {
+ inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
+ require.NoError(t, err, "Should create ConPty pipes successfully")
+
+ // Verify all handles are valid
+ assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid")
+ assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid")
+ assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid")
+ assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid")
+
+ // Clean up handles
+ closeHandles(inputRead, inputWrite, outputRead, outputWrite)
+}
+
+func TestCreateConPty(t *testing.T) {
+ inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
+ require.NoError(t, err, "Should create ConPty pipes successfully")
+ defer closeHandles(inputRead, inputWrite, outputRead, outputWrite)
+
+ hPty, err := createConPty(80, 24, inputRead, outputWrite)
+ require.NoError(t, err, "Should create ConPty successfully")
+ assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid")
+
+ // Clean up ConPty
+ ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty))
+ assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully")
+}
+
+func TestConvertEnvironmentToUTF16(t *testing.T) {
+ tests := []struct {
+ name string
+ userEnv []string
+ hasError bool
+ }{
+ {
+ name: "Valid environment variables",
+ userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"},
+ hasError: false,
+ },
+ {
+ name: "Empty environment",
+ userEnv: []string{},
+ hasError: false,
+ },
+ {
+ name: "Environment with empty strings",
+ userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"},
+ hasError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := convertEnvironmentToUTF16(tt.userEnv)
+ if tt.hasError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ if len(tt.userEnv) == 0 {
+ assert.Nil(t, result, "Empty environment should return nil")
+ } else {
+ assert.NotNil(t, result, "Non-empty environment should return valid pointer")
+ }
+ }
+ })
+ }
+}
+
+func TestDuplicateToPrimaryToken(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping token tests in short mode")
+ }
+
+ // Get current process token for testing
+ var token windows.Token
+ err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token)
+ require.NoError(t, err, "Should open current process token")
+ defer func() {
+ if err := windows.CloseHandle(windows.Handle(token)); err != nil {
+ t.Logf("Failed to close token: %v", err)
+ }
+ }()
+
+ primaryToken, err := duplicateToPrimaryToken(windows.Handle(token))
+ require.NoError(t, err, "Should duplicate token to primary")
+ assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid")
+
+ // Clean up
+ err = windows.CloseHandle(primaryToken)
+ assert.NoError(t, err, "Should close primary token")
+}
+
+func TestWindowsHandleReader(t *testing.T) {
+ // Create a pipe for testing
+ var readHandle, writeHandle windows.Handle
+ err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
+ require.NoError(t, err, "Should create pipe for testing")
+ defer closeHandles(readHandle, writeHandle)
+
+ // Write test data
+ testData := []byte("Hello, Windows Handle Reader!")
+ var bytesWritten uint32
+ err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil)
+ require.NoError(t, err, "Should write test data")
+ require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data")
+
+ // Close write handle to signal EOF
+ if err := windows.CloseHandle(writeHandle); err != nil {
+ t.Fatalf("Should close write handle: %v", err)
+ }
+ writeHandle = windows.InvalidHandle
+
+ // Test reading
+ reader := &windowsHandleReader{handle: readHandle}
+ buffer := make([]byte, len(testData))
+ n, err := reader.Read(buffer)
+ require.NoError(t, err, "Should read from handle")
+ assert.Equal(t, len(testData), n, "Should read expected number of bytes")
+ assert.Equal(t, testData, buffer, "Should read expected data")
+}
+
+func TestWindowsHandleWriter(t *testing.T) {
+ // Create a pipe for testing
+ var readHandle, writeHandle windows.Handle
+ err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
+ require.NoError(t, err, "Should create pipe for testing")
+ defer closeHandles(readHandle, writeHandle)
+
+ // Test writing
+ testData := []byte("Hello, Windows Handle Writer!")
+ writer := &windowsHandleWriter{handle: writeHandle}
+ n, err := writer.Write(testData)
+ require.NoError(t, err, "Should write to handle")
+ assert.Equal(t, len(testData), n, "Should write expected number of bytes")
+
+ // Close write handle
+ if err := windows.CloseHandle(writeHandle); err != nil {
+ t.Fatalf("Should close write handle: %v", err)
+ }
+
+ // Verify data was written by reading it back
+ buffer := make([]byte, len(testData))
+ var bytesRead uint32
+ err = windows.ReadFile(readHandle, buffer, &bytesRead, nil)
+ require.NoError(t, err, "Should read back written data")
+ assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes")
+ assert.Equal(t, testData, buffer, "Should read back expected data")
+}
+
+// BenchmarkConPtyCreation benchmarks ConPty creation performance
+func BenchmarkConPtyCreation(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ hPty, err := createConPty(80, 24, inputRead, outputWrite)
+ if err != nil {
+ closeHandles(inputRead, inputWrite, outputRead, outputWrite)
+ b.Fatal(err)
+ }
+
+ // Clean up
+ if ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)); ret == 0 {
+ log.Debugf("ClosePseudoConsole failed: %v", err)
+ }
+ closeHandles(inputRead, inputWrite, outputRead, outputWrite)
+ }
+}
diff --git a/client/ssh/server_mock.go b/client/ssh/server_mock.go
deleted file mode 100644
index 76f43fd4e..000000000
--- a/client/ssh/server_mock.go
+++ /dev/null
@@ -1,46 +0,0 @@
-//go:build !js
-
-package ssh
-
-import "context"
-
-// MockServer mocks ssh.Server
-type MockServer struct {
- Ctx context.Context
- StopFunc func() error
- StartFunc func() error
- AddAuthorizedKeyFunc func(peer, newKey string) error
- RemoveAuthorizedKeyFunc func(peer string)
-}
-
-// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
-func (srv *MockServer) RemoveAuthorizedKey(peer string) {
- if srv.RemoveAuthorizedKeyFunc == nil {
- return
- }
- srv.RemoveAuthorizedKeyFunc(peer)
-}
-
-// AddAuthorizedKey add a given peer key to server authorized keys
-func (srv *MockServer) AddAuthorizedKey(peer, newKey string) error {
- if srv.AddAuthorizedKeyFunc == nil {
- return nil
- }
- return srv.AddAuthorizedKeyFunc(peer, newKey)
-}
-
-// Stop stops SSH server.
-func (srv *MockServer) Stop() error {
- if srv.StopFunc == nil {
- return nil
- }
- return srv.StopFunc()
-}
-
-// Start starts SSH server. Blocking
-func (srv *MockServer) Start() error {
- if srv.StartFunc == nil {
- return nil
- }
- return srv.StartFunc()
-}
diff --git a/client/ssh/server_test.go b/client/ssh/server_test.go
deleted file mode 100644
index 1f310c2bb..000000000
--- a/client/ssh/server_test.go
+++ /dev/null
@@ -1,123 +0,0 @@
-//go:build !js
-
-package ssh
-
-import (
- "fmt"
- "github.com/stretchr/testify/assert"
- "golang.org/x/crypto/ssh"
- "strings"
- "testing"
-)
-
-func TestServer_AddAuthorizedKey(t *testing.T) {
- key, err := GeneratePrivateKey(ED25519)
- if err != nil {
- t.Fatal(err)
- }
- server, err := newDefaultServer(key, "localhost:")
- if err != nil {
- t.Fatal(err)
- }
-
- // add multiple keys
- keys := map[string][]byte{}
- for i := 0; i < 10; i++ {
- peer := fmt.Sprintf("%s-%d", "remotePeer", i)
- remotePrivKey, err := GeneratePrivateKey(ED25519)
- if err != nil {
- t.Fatal(err)
- }
- remotePubKey, err := GeneratePublicKey(remotePrivKey)
- if err != nil {
- t.Fatal(err)
- }
-
- err = server.AddAuthorizedKey(peer, string(remotePubKey))
- if err != nil {
- t.Error(err)
- }
- keys[peer] = remotePubKey
- }
-
- // make sure that all keys have been added
- for peer, remotePubKey := range keys {
- k, ok := server.authorizedKeys[peer]
- assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
-
- assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(ssh.MarshalAuthorizedKey(k))))
- }
-
-}
-
-func TestServer_RemoveAuthorizedKey(t *testing.T) {
- key, err := GeneratePrivateKey(ED25519)
- if err != nil {
- t.Fatal(err)
- }
- server, err := newDefaultServer(key, "localhost:")
- if err != nil {
- t.Fatal(err)
- }
-
- remotePrivKey, err := GeneratePrivateKey(ED25519)
- if err != nil {
- t.Fatal(err)
- }
- remotePubKey, err := GeneratePublicKey(remotePrivKey)
- if err != nil {
- t.Fatal(err)
- }
-
- err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
- if err != nil {
- t.Error(err)
- }
-
- server.RemoveAuthorizedKey("remotePeer")
-
- _, ok := server.authorizedKeys["remotePeer"]
- assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
-}
-
-func TestServer_PubKeyHandler(t *testing.T) {
- key, err := GeneratePrivateKey(ED25519)
- if err != nil {
- t.Fatal(err)
- }
- server, err := newDefaultServer(key, "localhost:")
- if err != nil {
- t.Fatal(err)
- }
-
- var keys []ssh.PublicKey
- for i := 0; i < 10; i++ {
- peer := fmt.Sprintf("%s-%d", "remotePeer", i)
- remotePrivKey, err := GeneratePrivateKey(ED25519)
- if err != nil {
- t.Fatal(err)
- }
- remotePubKey, err := GeneratePublicKey(remotePrivKey)
- if err != nil {
- t.Fatal(err)
- }
-
- remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
- if err != nil {
- t.Fatal(err)
- }
-
- err = server.AddAuthorizedKey(peer, string(remotePubKey))
- if err != nil {
- t.Error(err)
- }
- keys = append(keys, remoteParsedPubKey)
- }
-
- for _, key := range keys {
- accepted := server.publicKeyHandler(nil, key)
-
- assert.Truef(t, accepted, "expecting SSH connection to be accepted for a given SSH key %s", string(ssh.MarshalAuthorizedKey(key)))
- }
-
-}
diff --git a/client/ssh/util.go b/client/ssh/ssh.go
similarity index 86%
rename from client/ssh/util.go
rename to client/ssh/ssh.go
index a54a609bc..c0024c599 100644
--- a/client/ssh/util.go
+++ b/client/ssh/ssh.go
@@ -32,9 +32,8 @@ const RSA KeyType = "rsa"
// RSAKeySize is a size of newly generated RSA key
const RSAKeySize = 2048
-// GeneratePrivateKey creates RSA Private Key of specified byte size
+// GeneratePrivateKey creates a private key of the specified type.
func GeneratePrivateKey(keyType KeyType) ([]byte, error) {
-
var key crypto.Signer
var err error
switch keyType {
@@ -59,7 +58,7 @@ func GeneratePrivateKey(keyType KeyType) ([]byte, error) {
return pemBytes, nil
}
-// GeneratePublicKey returns the public part of the private key
+// GeneratePublicKey returns the public part of the private key.
func GeneratePublicKey(key []byte) ([]byte, error) {
signer, err := gossh.ParsePrivateKey(key)
if err != nil {
@@ -70,20 +69,17 @@ func GeneratePublicKey(key []byte) ([]byte, error) {
return []byte(strKey), nil
}
-// EncodePrivateKeyToPEM encodes Private Key from RSA to PEM format
+// EncodePrivateKeyToPEM encodes a private key to PEM format.
func EncodePrivateKeyToPEM(privateKey crypto.Signer) ([]byte, error) {
mk, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, err
}
- // pem.Block
privBlock := pem.Block{
Type: "PRIVATE KEY",
Bytes: mk,
}
-
- // Private key in PEM format
privatePEM := pem.EncodeToMemory(&privBlock)
return privatePEM, nil
}
diff --git a/client/ssh/testutil/user_helpers.go b/client/ssh/testutil/user_helpers.go
new file mode 100644
index 000000000..8960d8dd0
--- /dev/null
+++ b/client/ssh/testutil/user_helpers.go
@@ -0,0 +1,173 @@
+package testutil
+
+import (
+ "fmt"
+ "log"
+ "os"
+ "os/exec"
+ "os/user"
+ "runtime"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+var testCreatedUsers = make(map[string]bool)
+var testUsersToCleanup []string
+
+// GetTestUsername returns an appropriate username for testing
+func GetTestUsername(t *testing.T) string {
+ if runtime.GOOS == "windows" {
+ currentUser, err := user.Current()
+ require.NoError(t, err, "Should be able to get current user")
+
+ if IsSystemAccount(currentUser.Username) {
+ if IsCI() {
+ if testUser := GetOrCreateTestUser(t); testUser != "" {
+ return testUser
+ }
+ } else {
+ if _, err := user.Lookup("Administrator"); err == nil {
+ return "Administrator"
+ }
+ if testUser := GetOrCreateTestUser(t); testUser != "" {
+ return testUser
+ }
+ }
+ }
+ return currentUser.Username
+ }
+
+ currentUser, err := user.Current()
+ require.NoError(t, err, "Should be able to get current user")
+ return currentUser.Username
+}
+
+// IsCI checks if we're running in a CI environment
+func IsCI() bool {
+ if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
+ return true
+ }
+
+ hostname, err := os.Hostname()
+ if err == nil && strings.HasPrefix(hostname, "runner") {
+ return true
+ }
+
+ return false
+}
+
+// IsSystemAccount checks if the user is a system account that can't authenticate
+func IsSystemAccount(username string) bool {
+ systemAccounts := []string{
+ "system",
+ "NT AUTHORITY\\SYSTEM",
+ "NT AUTHORITY\\LOCAL SERVICE",
+ "NT AUTHORITY\\NETWORK SERVICE",
+ }
+
+ for _, sysAccount := range systemAccounts {
+ if strings.EqualFold(username, sysAccount) {
+ return true
+ }
+ }
+
+ return strings.HasSuffix(username, "$")
+}
+
+// RegisterTestUserCleanup registers a test user for cleanup
+func RegisterTestUserCleanup(username string) {
+ if !testCreatedUsers[username] {
+ testCreatedUsers[username] = true
+ testUsersToCleanup = append(testUsersToCleanup, username)
+ }
+}
+
+// CleanupTestUsers removes all created test users
+func CleanupTestUsers() {
+ for _, username := range testUsersToCleanup {
+ RemoveWindowsTestUser(username)
+ }
+ testUsersToCleanup = nil
+ testCreatedUsers = make(map[string]bool)
+}
+
+// GetOrCreateTestUser creates a test user on Windows if needed
+func GetOrCreateTestUser(t *testing.T) string {
+ testUsername := "netbird-test-user"
+
+ if _, err := user.Lookup(testUsername); err == nil {
+ return testUsername
+ }
+
+ if CreateWindowsTestUser(t, testUsername) {
+ RegisterTestUserCleanup(testUsername)
+ return testUsername
+ }
+
+ return ""
+}
+
+// RemoveWindowsTestUser removes a local user on Windows using PowerShell
+func RemoveWindowsTestUser(username string) {
+ if runtime.GOOS != "windows" {
+ return
+ }
+
+ psCmd := fmt.Sprintf(`
+ try {
+ Remove-LocalUser -Name "%s" -ErrorAction Stop
+ Write-Output "User removed successfully"
+ } catch {
+ if ($_.Exception.Message -like "*cannot be found*") {
+ Write-Output "User not found (already removed)"
+ } else {
+ Write-Error $_.Exception.Message
+ }
+ }
+ `, username)
+
+ cmd := exec.Command("powershell", "-Command", psCmd)
+ output, err := cmd.CombinedOutput()
+
+ if err != nil {
+ log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
+ } else {
+ log.Printf("Test user %s cleanup result: %s", username, string(output))
+ }
+}
+
+// CreateWindowsTestUser creates a local user on Windows using PowerShell
+func CreateWindowsTestUser(t *testing.T, username string) bool {
+ if runtime.GOOS != "windows" {
+ return false
+ }
+
+ psCmd := fmt.Sprintf(`
+ try {
+ $password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
+ New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
+ Add-LocalGroupMember -Group "Users" -Member "%s"
+ Write-Output "User created successfully"
+ } catch {
+ if ($_.Exception.Message -like "*already exists*") {
+ Write-Output "User already exists"
+ } else {
+ Write-Error $_.Exception.Message
+ exit 1
+ }
+ }
+ `, username, username)
+
+ cmd := exec.Command("powershell", "-Command", psCmd)
+ output, err := cmd.CombinedOutput()
+
+ if err != nil {
+ t.Logf("Failed to create test user: %v, output: %s", err, string(output))
+ return false
+ }
+
+ t.Logf("Test user creation result: %s", string(output))
+ return true
+}
diff --git a/client/ssh/testutil/user_helpers_test.go b/client/ssh/testutil/user_helpers_test.go
new file mode 100644
index 000000000..db2f5f06d
--- /dev/null
+++ b/client/ssh/testutil/user_helpers_test.go
@@ -0,0 +1,115 @@
+package testutil
+
+import (
+ "os/user"
+ "runtime"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestUserCurrentBehavior validates user.Current() behavior on Windows.
+// When running as SYSTEM on a domain-joined machine, user.Current() returns:
+// - Username: Computer account name (e.g., "DOMAIN\MACHINE$")
+// - SID: SYSTEM SID (S-1-5-18)
+func TestUserCurrentBehavior(t *testing.T) {
+ if runtime.GOOS != "windows" {
+ t.Skip("Windows-specific test")
+ }
+
+ currentUser, err := user.Current()
+ require.NoError(t, err, "Should be able to get current user")
+
+ t.Logf("Current user - Username: %s, SID: %s", currentUser.Username, currentUser.Uid)
+
+ // When running as SYSTEM, validate expected behavior
+ if currentUser.Uid == "S-1-5-18" {
+ t.Run("SYSTEM_account_behavior", func(t *testing.T) {
+ // SID must be S-1-5-18 for SYSTEM
+ require.Equal(t, "S-1-5-18", currentUser.Uid,
+ "SYSTEM account must have SID S-1-5-18")
+
+ // Username can be either "NT AUTHORITY\SYSTEM" (standalone)
+ // or "DOMAIN\MACHINE$" (domain-joined)
+ username := currentUser.Username
+ isNTAuthority := strings.Contains(strings.ToUpper(username), "NT AUTHORITY")
+ isComputerAccount := strings.HasSuffix(username, "$")
+
+ assert.True(t, isNTAuthority || isComputerAccount,
+ "Username should be either 'NT AUTHORITY\\SYSTEM' or computer account (ending with $), got: %s",
+ username)
+
+ if isComputerAccount {
+ t.Logf("SYSTEM as computer account: %s", username)
+ } else if isNTAuthority {
+ t.Logf("SYSTEM as NT AUTHORITY\\SYSTEM")
+ }
+ })
+ }
+
+ // Validate that IsSystemAccount correctly identifies system accounts
+ t.Run("IsSystemAccount_validation", func(t *testing.T) {
+ // Test with current user if it's a system account
+ if currentUser.Uid == "S-1-5-18" || // SYSTEM
+ currentUser.Uid == "S-1-5-19" || // LOCAL SERVICE
+ currentUser.Uid == "S-1-5-20" { // NETWORK SERVICE
+
+ result := IsSystemAccount(currentUser.Username)
+ assert.True(t, result,
+ "IsSystemAccount should recognize system account: %s (SID: %s)",
+ currentUser.Username, currentUser.Uid)
+ }
+
+ // Test explicit cases
+ testCases := []struct {
+ username string
+ expected bool
+ reason string
+ }{
+ {"NT AUTHORITY\\SYSTEM", true, "NT AUTHORITY\\SYSTEM"},
+ {"system", true, "system"},
+ {"SYSTEM", true, "SYSTEM (case insensitive)"},
+ {"NT AUTHORITY\\LOCAL SERVICE", true, "LOCAL SERVICE"},
+ {"NT AUTHORITY\\NETWORK SERVICE", true, "NETWORK SERVICE"},
+ {"DOMAIN\\MACHINE$", true, "computer account (ends with $)"},
+ {"WORKGROUP\\WIN2K19-C2$", true, "computer account (ends with $)"},
+ {"Administrator", false, "Administrator is not a system account"},
+ {"alice", false, "regular user"},
+ {"DOMAIN\\alice", false, "domain user"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.username, func(t *testing.T) {
+ result := IsSystemAccount(tc.username)
+ assert.Equal(t, tc.expected, result,
+ "IsSystemAccount(%q) should be %v because: %s",
+ tc.username, tc.expected, tc.reason)
+ })
+ }
+ })
+}
+
+// TestComputerAccountDetection validates computer account detection.
+func TestComputerAccountDetection(t *testing.T) {
+ if runtime.GOOS != "windows" {
+ t.Skip("Windows-specific test")
+ }
+
+ computerAccounts := []string{
+ "MACHINE$",
+ "WIN2K19-C2$",
+ "DOMAIN\\MACHINE$",
+ "WORKGROUP\\SERVER$",
+ "server.domain.com$",
+ }
+
+ for _, account := range computerAccounts {
+ t.Run(account, func(t *testing.T) {
+ result := IsSystemAccount(account)
+ assert.True(t, result,
+ "Computer account %q should be recognized as system account", account)
+ })
+ }
+}
diff --git a/client/ssh/window_freebsd.go b/client/ssh/window_freebsd.go
deleted file mode 100644
index ef4848341..000000000
--- a/client/ssh/window_freebsd.go
+++ /dev/null
@@ -1,10 +0,0 @@
-//go:build freebsd
-
-package ssh
-
-import (
- "os"
-)
-
-func setWinSize(file *os.File, width, height int) {
-}
diff --git a/client/ssh/window_unix.go b/client/ssh/window_unix.go
deleted file mode 100644
index 2891eb70e..000000000
--- a/client/ssh/window_unix.go
+++ /dev/null
@@ -1,14 +0,0 @@
-//go:build linux || darwin
-
-package ssh
-
-import (
- "os"
- "syscall"
- "unsafe"
-)
-
-func setWinSize(file *os.File, width, height int) {
- syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TIOCSWINSZ), //nolint
- uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(height), uint16(width), 0, 0})))
-}
diff --git a/client/ssh/window_windows.go b/client/ssh/window_windows.go
deleted file mode 100644
index 5abd41f27..000000000
--- a/client/ssh/window_windows.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package ssh
-
-import (
- "os"
-)
-
-func setWinSize(file *os.File, width, height int) {
-
-}
diff --git a/client/status/status.go b/client/status/status.go
index db5b7dc0b..d975f0e29 100644
--- a/client/status/status.go
+++ b/client/status/status.go
@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
+ probeRelay "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/version"
@@ -80,6 +81,18 @@ type NsServerGroupStateOutput struct {
Error string `json:"error" yaml:"error"`
}
+type SSHSessionOutput struct {
+ Username string `json:"username" yaml:"username"`
+ RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
+ Command string `json:"command" yaml:"command"`
+ JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
+}
+
+type SSHServerStateOutput struct {
+ Enabled bool `json:"enabled" yaml:"enabled"`
+ Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
+}
+
type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -99,6 +112,7 @@ type OutputOverview struct {
Events []SystemEventOutput `json:"events" yaml:"events"`
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"`
+ SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
}
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
@@ -120,6 +134,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
+ sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
overview := OutputOverview{
Peers: peersOverview,
@@ -140,6 +155,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
Events: mapEvents(pbFullStatus.GetEvents()),
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: profName,
+ SSHServerState: sshServerOverview,
}
if anon {
@@ -189,6 +205,30 @@ func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
return mappedNSGroups
}
+func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
+ if sshServerState == nil {
+ return SSHServerStateOutput{
+ Enabled: false,
+ Sessions: []SSHSessionOutput{},
+ }
+ }
+
+ sessions := make([]SSHSessionOutput, 0, len(sshServerState.GetSessions()))
+ for _, session := range sshServerState.GetSessions() {
+ sessions = append(sessions, SSHSessionOutput{
+ Username: session.GetUsername(),
+ RemoteAddress: session.GetRemoteAddress(),
+ Command: session.GetCommand(),
+ JWTUsername: session.GetJwtUsername(),
+ })
+ }
+
+ return SSHServerStateOutput{
+ Enabled: sshServerState.GetEnabled(),
+ Sessions: sessions,
+ }
+}
+
func mapPeers(
peers []*proto.PeerState,
statusFilter string,
@@ -205,15 +245,18 @@ func mapPeers(
localICEEndpoint := ""
remoteICEEndpoint := ""
relayServerAddress := ""
- connType := "P2P"
+ connType := "-"
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
- if pbPeerState.Relayed {
- connType = "Relayed"
+ if isPeerConnected {
+ connType = "P2P"
+ if pbPeerState.Relayed {
+ connType = "Relayed"
+ }
}
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {
@@ -296,7 +339,7 @@ func ParseToYAML(overview OutputOverview) (string, error) {
return string(yamlBytes), nil
}
-func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
+func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
var managementConnString string
if overview.ManagementState.Connected {
managementConnString = "Connected"
@@ -337,10 +380,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
for _, relay := range overview.Relays.Details {
available := "Available"
reason := ""
+
if !relay.Available {
- available = "Unavailable"
- reason = fmt.Sprintf(", reason: %s", relay.Error)
+ if relay.Error == probeRelay.ErrCheckInProgress.Error() {
+ available = "Checking..."
+ } else {
+ available = "Unavailable"
+ reason = fmt.Sprintf(", reason: %s", relay.Error)
+ }
}
+
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
@@ -395,6 +444,41 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
lazyConnectionEnabledStatus = "true"
}
+ sshServerStatus := "Disabled"
+ if overview.SSHServerState.Enabled {
+ sessionCount := len(overview.SSHServerState.Sessions)
+ if sessionCount > 0 {
+ sessionWord := "session"
+ if sessionCount > 1 {
+ sessionWord = "sessions"
+ }
+ sshServerStatus = fmt.Sprintf("Enabled (%d active %s)", sessionCount, sessionWord)
+ } else {
+ sshServerStatus = "Enabled"
+ }
+
+ if showSSHSessions && sessionCount > 0 {
+ for _, session := range overview.SSHServerState.Sessions {
+ var sessionDisplay string
+ if session.JWTUsername != "" {
+ sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
+ session.JWTUsername,
+ session.RemoteAddress,
+ session.Username,
+ session.Command,
+ )
+ } else {
+ sessionDisplay = fmt.Sprintf("[%s@%s] %s",
+ session.Username,
+ session.RemoteAddress,
+ session.Command,
+ )
+ }
+ sshServerStatus += "\n " + sessionDisplay
+ }
+ }
+ }
+
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
@@ -418,6 +502,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
+ "SSH Server: %s\n"+
"Networks: %s\n"+
"Forwarding rules: %d\n"+
"Peers count: %s\n",
@@ -434,6 +519,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
interfaceTypeString,
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
+ sshServerStatus,
networks,
overview.NumberOfForwardingRules,
peersCountString,
@@ -444,7 +530,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
func ParseToFullDetailSummary(overview OutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
parsedEventsString := parseEvents(overview.Events)
- summary := ParseGeneralSummary(overview, true, true, true)
+ summary := ParseGeneralSummary(overview, true, true, true, true)
return fmt.Sprintf(
"Peers detail:"+
@@ -736,4 +822,13 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
event.Metadata[k] = a.AnonymizeString(v)
}
}
+
+ for i, session := range overview.SSHServerState.Sessions {
+ if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
+ overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
+ } else {
+ overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
+ }
+ overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
+ }
}
diff --git a/client/status/status_test.go b/client/status/status_test.go
index 660efd9ef..1dca1e5b1 100644
--- a/client/status/status_test.go
+++ b/client/status/status_test.go
@@ -231,6 +231,10 @@ var overview = OutputOverview{
Networks: []string{
"10.10.0.0/24",
},
+ SSHServerState: SSHServerStateOutput{
+ Enabled: false,
+ Sessions: []SSHSessionOutput{},
+ },
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -385,7 +389,11 @@ func TestParsingToJSON(t *testing.T) {
],
"events": [],
"lazyConnectionEnabled": false,
- "profileName":""
+ "profileName":"",
+ "sshServer":{
+ "enabled":false,
+ "sessions":[]
+ }
}`
// @formatter:on
@@ -488,6 +496,9 @@ dnsServers:
events: []
lazyConnectionEnabled: false
profileName: ""
+sshServer:
+ enabled: false
+ sessions: []
`
assert.Equal(t, expectedYAML, yaml)
@@ -554,6 +565,7 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Lazy connection: false
+SSH Server: Disabled
Networks: 10.10.0.0/24
Forwarding rules: 0
Peers count: 2/2 Connected
@@ -563,7 +575,7 @@ Peers count: 2/2 Connected
}
func TestParsingToShortVersion(t *testing.T) {
- shortVersion := ParseGeneralSummary(overview, false, false, false)
+ shortVersion := ParseGeneralSummary(overview, false, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
@@ -578,6 +590,7 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Lazy connection: false
+SSH Server: Disabled
Networks: 10.10.0.0/24
Forwarding rules: 0
Peers count: 2/2 Connected
diff --git a/client/system/info.go b/client/system/info.go
index a180be4c0..01176e765 100644
--- a/client/system/info.go
+++ b/client/system/info.go
@@ -72,6 +72,12 @@ type Info struct {
BlockInbound bool
LazyConnectionEnabled bool
+
+ EnableSSHRoot bool
+ EnableSSHSFTP bool
+ EnableSSHLocalPortForwarding bool
+ EnableSSHRemotePortForwarding bool
+ DisableSSHAuth bool
}
func (i *Info) SetFlags(
@@ -79,6 +85,8 @@ func (i *Info) SetFlags(
serverSSHAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
+ enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
+ disableSSHAuth *bool,
) {
i.RosenpassEnabled = rosenpassEnabled
i.RosenpassPermissive = rosenpassPermissive
@@ -94,6 +102,22 @@ func (i *Info) SetFlags(
i.BlockInbound = blockInbound
i.LazyConnectionEnabled = lazyConnectionEnabled
+
+ if enableSSHRoot != nil {
+ i.EnableSSHRoot = *enableSSHRoot
+ }
+ if enableSSHSFTP != nil {
+ i.EnableSSHSFTP = *enableSSHSFTP
+ }
+ if enableSSHLocalPortForwarding != nil {
+ i.EnableSSHLocalPortForwarding = *enableSSHLocalPortForwarding
+ }
+ if enableSSHRemotePortForwarding != nil {
+ i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding
+ }
+ if disableSSHAuth != nil {
+ i.DisableSSHAuth = *disableSSHAuth
+ }
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
diff --git a/client/ui/assets/netbird-disconnected.ico b/client/ui/assets/netbird-disconnected.ico
new file mode 100644
index 000000000..812e9d283
Binary files /dev/null and b/client/ui/assets/netbird-disconnected.ico differ
diff --git a/client/ui/assets/netbird-disconnected.png b/client/ui/assets/netbird-disconnected.png
new file mode 100644
index 000000000..79d4775ea
Binary files /dev/null and b/client/ui/assets/netbird-disconnected.png differ
diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go
index 7c2000a9d..87bac8c31 100644
--- a/client/ui/client_ui.go
+++ b/client/ui/client_ui.go
@@ -31,19 +31,19 @@ import (
"fyne.io/systray"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
- "github.com/skratchdot/open-golang/open"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
+ protobuf "google.golang.org/protobuf/proto"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
+ "github.com/netbirdio/netbird/client/internal/sleep"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ui/desktop"
"github.com/netbirdio/netbird/client/ui/event"
"github.com/netbirdio/netbird/client/ui/process"
-
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
@@ -56,6 +56,7 @@ const (
const (
censoredPreSharedKey = "**********"
+ maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds
)
func main() {
@@ -86,21 +87,24 @@ func main() {
// Create the service client (this also builds the settings or networks UI if requested).
client := newServiceClient(&newServiceClientArgs{
- addr: flags.daemonAddr,
- logFile: logFile,
- app: a,
- showSettings: flags.showSettings,
- showNetworks: flags.showNetworks,
- showLoginURL: flags.showLoginURL,
- showDebug: flags.showDebug,
- showProfiles: flags.showProfiles,
+ addr: flags.daemonAddr,
+ logFile: logFile,
+ app: a,
+ showSettings: flags.showSettings,
+ showNetworks: flags.showNetworks,
+ showLoginURL: flags.showLoginURL,
+ showDebug: flags.showDebug,
+ showProfiles: flags.showProfiles,
+ showQuickActions: flags.showQuickActions,
+ showUpdate: flags.showUpdate,
+ showUpdateVersion: flags.showUpdateVersion,
})
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
- if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles {
+ if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate {
a.Run()
return
}
@@ -112,23 +116,31 @@ func main() {
return
}
if running {
- log.Warnf("another process is running with pid %d, exiting", pid)
+ log.Infof("another process is running with pid %d, sending signal to show window", pid)
+ if err := sendShowWindowSignal(pid); err != nil {
+ log.Errorf("send signal to running instance: %v", err)
+ }
return
}
+ client.setupSignalHandler(client.ctx)
+
client.setDefaultFonts()
systray.Run(client.onTrayReady, client.onTrayExit)
}
type cliFlags struct {
- daemonAddr string
- showSettings bool
- showNetworks bool
- showProfiles bool
- showDebug bool
- showLoginURL bool
- errorMsg string
- saveLogsInFile bool
+ daemonAddr string
+ showSettings bool
+ showNetworks bool
+ showProfiles bool
+ showDebug bool
+ showLoginURL bool
+ showQuickActions bool
+ errorMsg string
+ saveLogsInFile bool
+ showUpdate bool
+ showUpdateVersion string
}
// parseFlags reads and returns all needed command-line flags.
@@ -144,9 +156,12 @@ func parseFlags() *cliFlags {
flag.BoolVar(&flags.showNetworks, "networks", false, "run networks window")
flag.BoolVar(&flags.showProfiles, "profiles", false, "run profiles window")
flag.BoolVar(&flags.showDebug, "debug", false, "run debug window")
+ flag.BoolVar(&flags.showQuickActions, "quick-actions", false, "run quick actions window")
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
+ flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
+ flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to")
flag.Parse()
return &flags
}
@@ -159,11 +174,9 @@ func initLogFile() (string, error) {
// watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon.
func watchSettingsChanges(a fyne.App, client *serviceClient) {
- settingsChangeChan := make(chan fyne.Settings)
- a.Settings().AddChangeListener(settingsChangeChan)
- for range settingsChangeChan {
+ a.Settings().AddListener(func(settings fyne.Settings) {
client.updateIcon()
- }
+ })
}
// showErrorMessage displays an error message in a simple window.
@@ -203,10 +216,11 @@ var iconConnectedDot []byte
var iconDisconnectedDot []byte
type serviceClient struct {
- ctx context.Context
- cancel context.CancelFunc
- addr string
- conn proto.DaemonServiceClient
+ ctx context.Context
+ cancel context.CancelFunc
+ addr string
+ conn proto.DaemonServiceClient
+ connLock sync.Mutex
eventHandler *eventHandler
@@ -260,34 +274,50 @@ type serviceClient struct {
iMTU *widget.Entry
// switch elements for settings form
- sRosenpassPermissive *widget.Check
- sNetworkMonitor *widget.Check
- sDisableDNS *widget.Check
- sDisableClientRoutes *widget.Check
- sDisableServerRoutes *widget.Check
- sBlockLANAccess *widget.Check
+ sRosenpassPermissive *widget.Check
+ sNetworkMonitor *widget.Check
+ sDisableDNS *widget.Check
+ sDisableClientRoutes *widget.Check
+ sDisableServerRoutes *widget.Check
+ sBlockLANAccess *widget.Check
+ sEnableSSHRoot *widget.Check
+ sEnableSSHSFTP *widget.Check
+ sEnableSSHLocalPortForward *widget.Check
+ sEnableSSHRemotePortForward *widget.Check
+ sDisableSSHAuth *widget.Check
+ iSSHJWTCacheTTL *widget.Entry
// observable settings over corresponding iMngURL and iPreSharedKey values.
- managementURL string
- preSharedKey string
- RosenpassPermissive bool
- interfaceName string
- interfacePort int
- mtu uint16
- networkMonitor bool
- disableDNS bool
- disableClientRoutes bool
- disableServerRoutes bool
- blockLANAccess bool
+ managementURL string
+ preSharedKey string
+
+ RosenpassPermissive bool
+ interfaceName string
+ interfacePort int
+ mtu uint16
+ networkMonitor bool
+ disableDNS bool
+ disableClientRoutes bool
+ disableServerRoutes bool
+ blockLANAccess bool
+ enableSSHRoot bool
+ enableSSHSFTP bool
+ enableSSHLocalPortForward bool
+ enableSSHRemotePortForward bool
+ disableSSHAuth bool
+ sshJWTCacheTTL int
connected bool
update *version.Update
daemonVersion string
updateIndicationLock sync.Mutex
isUpdateIconActive bool
+ settingsEnabled bool
+ profilesEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
+ wQuickActions fyne.Window
eventManager *event.Manager
@@ -297,6 +327,10 @@ type serviceClient struct {
mExitNodeDeselectAll *systray.MenuItem
logFile string
wLoginURL fyne.Window
+ wUpdateProgress fyne.Window
+ updateContextCancel context.CancelFunc
+
+ connectCancel context.CancelFunc
}
type menuHandler struct {
@@ -305,14 +339,17 @@ type menuHandler struct {
}
type newServiceClientArgs struct {
- addr string
- logFile string
- app fyne.App
- showSettings bool
- showNetworks bool
- showDebug bool
- showLoginURL bool
- showProfiles bool
+ addr string
+ logFile string
+ app fyne.App
+ showSettings bool
+ showNetworks bool
+ showDebug bool
+ showLoginURL bool
+ showProfiles bool
+ showQuickActions bool
+ showUpdate bool
+ showUpdateVersion string
}
// newServiceClient instance constructor
@@ -330,7 +367,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
showAdvancedSettings: args.showSettings,
showNetworks: args.showNetworks,
- update: version.NewUpdate("nb/client-ui"),
+ update: version.NewUpdateAndStart("nb/client-ui"),
}
s.eventHandler = newEventHandler(s)
@@ -348,6 +385,10 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
s.showDebugUI()
case args.showProfiles:
s.showProfilesUI()
+ case args.showQuickActions:
+ s.showQuickActionsUI()
+ case args.showUpdate:
+ s.showUpdateProgress(ctx, args.showUpdateVersion)
}
return s
@@ -424,18 +465,22 @@ func (s *serviceClient) showSettingsUI() {
s.sDisableClientRoutes = widget.NewCheck("This peer won't route traffic to other peers", nil)
s.sDisableServerRoutes = widget.NewCheck("This peer won't act as router for others", nil)
s.sBlockLANAccess = widget.NewCheck("Blocks local network access when used as exit node", nil)
+ s.sEnableSSHRoot = widget.NewCheck("Enable SSH Root Login", nil)
+ s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil)
+ s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil)
+ s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
+ s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
+ s.iSSHJWTCacheTTL = widget.NewEntry()
s.wSettings.SetContent(s.getSettingsForm())
- s.wSettings.Resize(fyne.NewSize(600, 500))
+ s.wSettings.Resize(fyne.NewSize(600, 400))
s.wSettings.SetFixedSize(true)
s.getSrvConfig()
s.wSettings.Show()
}
-// getSettingsForm to embed it into settings window.
-func (s *serviceClient) getSettingsForm() *widget.Form {
-
+func (s *serviceClient) getConnectionForm() *widget.Form {
var activeProfName string
activeProf, err := s.profileManager.GetActiveProfile()
if err != nil {
@@ -446,164 +491,286 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
return &widget.Form{
Items: []*widget.FormItem{
{Text: "Profile", Widget: widget.NewLabel(activeProfName)},
+ {Text: "Management URL", Widget: s.iMngURL},
+ {Text: "Pre-shared Key", Widget: s.iPreSharedKey},
{Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive},
{Text: "Interface Name", Widget: s.iInterfaceName},
{Text: "Interface Port", Widget: s.iInterfacePort},
{Text: "MTU", Widget: s.iMTU},
- {Text: "Management URL", Widget: s.iMngURL},
- {Text: "Pre-shared Key", Widget: s.iPreSharedKey},
{Text: "Log File", Widget: s.iLogFile},
+ },
+ }
+}
+
+func (s *serviceClient) saveSettings() {
+ // Check if update settings are disabled by daemon
+ features, err := s.getFeatures()
+ if err != nil {
+ log.Errorf("failed to get features from daemon: %v", err)
+ // Continue with default behavior if features can't be retrieved
+ } else if features != nil && features.DisableUpdateSettings {
+ log.Warn("Configuration updates are disabled by daemon")
+ dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
+ return
+ }
+
+ if err := s.validateSettings(); err != nil {
+ dialog.ShowError(err, s.wSettings)
+ return
+ }
+
+ port, mtu, err := s.parseNumericSettings()
+ if err != nil {
+ dialog.ShowError(err, s.wSettings)
+ return
+ }
+
+ iMngURL := strings.TrimSpace(s.iMngURL.Text)
+
+ if s.hasSettingsChanged(iMngURL, port, mtu) {
+ if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil {
+ dialog.ShowError(err, s.wSettings)
+ return
+ }
+ }
+
+ s.wSettings.Close()
+}
+
+func (s *serviceClient) validateSettings() error {
+ if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
+ if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
+ return fmt.Errorf("Invalid Pre-shared Key Value")
+ }
+ }
+ return nil
+}
+
+func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
+ port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
+ if err != nil {
+ return 0, 0, errors.New("Invalid interface port")
+ }
+ if port < 1 || port > 65535 {
+ return 0, 0, errors.New("Invalid interface port: out of range 1-65535")
+ }
+
+ var mtu int64
+ mtuText := strings.TrimSpace(s.iMTU.Text)
+ if mtuText != "" {
+ mtu, err = strconv.ParseInt(mtuText, 10, 64)
+ if err != nil {
+ return 0, 0, errors.New("Invalid MTU value")
+ }
+ if mtu < iface.MinMTU || mtu > iface.MaxMTU {
+ return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU)
+ }
+ }
+
+ return port, mtu, nil
+}
+
+func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool {
+ return s.managementURL != iMngURL ||
+ s.preSharedKey != s.iPreSharedKey.Text ||
+ s.RosenpassPermissive != s.sRosenpassPermissive.Checked ||
+ s.interfaceName != s.iInterfaceName.Text ||
+ s.interfacePort != int(port) ||
+ s.mtu != uint16(mtu) ||
+ s.networkMonitor != s.sNetworkMonitor.Checked ||
+ s.disableDNS != s.sDisableDNS.Checked ||
+ s.disableClientRoutes != s.sDisableClientRoutes.Checked ||
+ s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
+ s.blockLANAccess != s.sBlockLANAccess.Checked ||
+ s.hasSSHChanges()
+}
+
+func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error {
+ s.managementURL = iMngURL
+ s.preSharedKey = s.iPreSharedKey.Text
+ s.mtu = uint16(mtu)
+
+ req, err := s.buildSetConfigRequest(iMngURL, port, mtu)
+ if err != nil {
+ return fmt.Errorf("build config request: %w", err)
+ }
+
+ if err := s.sendConfigUpdate(req); err != nil {
+ return fmt.Errorf("set configuration: %w", err)
+ }
+
+ return nil
+}
+
+func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (*proto.SetConfigRequest, error) {
+ currUser, err := user.Current()
+ if err != nil {
+ return nil, fmt.Errorf("get current user: %w", err)
+ }
+
+ activeProf, err := s.profileManager.GetActiveProfile()
+ if err != nil {
+ return nil, fmt.Errorf("get active profile: %w", err)
+ }
+
+ req := &proto.SetConfigRequest{
+ ProfileName: activeProf.Name,
+ Username: currUser.Username,
+ }
+
+ if iMngURL != "" {
+ req.ManagementUrl = iMngURL
+ }
+
+ req.RosenpassPermissive = &s.sRosenpassPermissive.Checked
+ req.InterfaceName = &s.iInterfaceName.Text
+ req.WireguardPort = &port
+ if mtu > 0 {
+ req.Mtu = &mtu
+ }
+
+ req.NetworkMonitor = &s.sNetworkMonitor.Checked
+ req.DisableDns = &s.sDisableDNS.Checked
+ req.DisableClientRoutes = &s.sDisableClientRoutes.Checked
+ req.DisableServerRoutes = &s.sDisableServerRoutes.Checked
+ req.BlockLanAccess = &s.sBlockLANAccess.Checked
+
+ req.EnableSSHRoot = &s.sEnableSSHRoot.Checked
+ req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked
+ req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked
+ req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked
+ req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
+
+ sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text)
+ if sshJWTCacheTTLText != "" {
+ sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32)
+ if err != nil {
+ return nil, errors.New("Invalid SSH JWT Cache TTL value")
+ }
+ if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL {
+ return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL)
+ }
+ sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
+ req.SshJWTCacheTTL = &sshJWTCacheTTL32
+ }
+
+ if s.iPreSharedKey.Text != censoredPreSharedKey {
+ req.OptionalPreSharedKey = &s.iPreSharedKey.Text
+ }
+
+ return req, nil
+}
+
+func (s *serviceClient) sendConfigUpdate(req *proto.SetConfigRequest) error {
+ conn, err := s.getSrvClient(failFastTimeout)
+ if err != nil {
+ return fmt.Errorf("get client: %w", err)
+ }
+
+ _, err = conn.SetConfig(s.ctx, req)
+ if err != nil {
+ return fmt.Errorf("set config: %w", err)
+ }
+
+ // Reconnect if connected to apply the new settings
+ go func() {
+ status, err := conn.Status(s.ctx, &proto.StatusRequest{})
+ if err != nil {
+ log.Errorf("get service status: %v", err)
+ return
+ }
+ if status.Status == string(internal.StatusConnected) {
+ // run down & up
+ _, err = conn.Down(s.ctx, &proto.DownRequest{})
+ if err != nil {
+ log.Errorf("down service: %v", err)
+ }
+
+ _, err = conn.Up(s.ctx, &proto.UpRequest{})
+ if err != nil {
+ log.Errorf("up service: %v", err)
+ return
+ }
+ }
+ }()
+
+ return nil
+}
+
+func (s *serviceClient) getSettingsForm() fyne.CanvasObject {
+ connectionForm := s.getConnectionForm()
+ networkForm := s.getNetworkForm()
+ sshForm := s.getSSHForm()
+ tabs := container.NewAppTabs(
+ container.NewTabItem("Connection", connectionForm),
+ container.NewTabItem("Network", networkForm),
+ container.NewTabItem("SSH", sshForm),
+ )
+ saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings)
+ saveButton.Importance = widget.HighImportance
+ cancelButton := widget.NewButtonWithIcon("Cancel", theme.CancelIcon(), func() {
+ s.wSettings.Close()
+ })
+ buttonContainer := container.NewHBox(
+ layout.NewSpacer(),
+ cancelButton,
+ saveButton,
+ )
+ return container.NewBorder(nil, buttonContainer, nil, nil, tabs)
+}
+
+func (s *serviceClient) getNetworkForm() *widget.Form {
+ return &widget.Form{
+ Items: []*widget.FormItem{
{Text: "Network Monitor", Widget: s.sNetworkMonitor},
{Text: "Disable DNS", Widget: s.sDisableDNS},
{Text: "Disable Client Routes", Widget: s.sDisableClientRoutes},
{Text: "Disable Server Routes", Widget: s.sDisableServerRoutes},
{Text: "Disable LAN Access", Widget: s.sBlockLANAccess},
},
- SubmitText: "Save",
- OnSubmit: func() {
- // Check if update settings are disabled by daemon
- features, err := s.getFeatures()
- if err != nil {
- log.Errorf("failed to get features from daemon: %v", err)
- // Continue with default behavior if features can't be retrieved
- } else if features != nil && features.DisableUpdateSettings {
- log.Warn("Configuration updates are disabled by daemon")
- dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
- return
- }
+ }
+}
- if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
- // validate preSharedKey if it added
- if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
- dialog.ShowError(fmt.Errorf("Invalid Pre-shared Key Value"), s.wSettings)
- return
- }
- }
-
- port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
- if err != nil {
- dialog.ShowError(errors.New("Invalid interface port"), s.wSettings)
- return
- }
-
- var mtu int64
- mtuText := strings.TrimSpace(s.iMTU.Text)
- if mtuText != "" {
- var err error
- mtu, err = strconv.ParseInt(mtuText, 10, 64)
- if err != nil {
- dialog.ShowError(errors.New("Invalid MTU value"), s.wSettings)
- return
- }
- if mtu < iface.MinMTU || mtu > iface.MaxMTU {
- dialog.ShowError(fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU), s.wSettings)
- return
- }
- }
-
- iMngURL := strings.TrimSpace(s.iMngURL.Text)
-
- defer s.wSettings.Close()
-
- // Check if any settings have changed
- if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text ||
- s.RosenpassPermissive != s.sRosenpassPermissive.Checked ||
- s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) ||
- s.mtu != uint16(mtu) ||
- s.networkMonitor != s.sNetworkMonitor.Checked ||
- s.disableDNS != s.sDisableDNS.Checked ||
- s.disableClientRoutes != s.sDisableClientRoutes.Checked ||
- s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
- s.blockLANAccess != s.sBlockLANAccess.Checked {
-
- s.managementURL = iMngURL
- s.preSharedKey = s.iPreSharedKey.Text
- s.mtu = uint16(mtu)
-
- currUser, err := user.Current()
- if err != nil {
- log.Errorf("get current user: %v", err)
- return
- }
-
- var req proto.SetConfigRequest
- req.ProfileName = activeProf.Name
- req.Username = currUser.Username
-
- if iMngURL != "" {
- req.ManagementUrl = iMngURL
- }
-
- req.RosenpassPermissive = &s.sRosenpassPermissive.Checked
- req.InterfaceName = &s.iInterfaceName.Text
- req.WireguardPort = &port
- if mtu > 0 {
- req.Mtu = &mtu
- }
- req.NetworkMonitor = &s.sNetworkMonitor.Checked
- req.DisableDns = &s.sDisableDNS.Checked
- req.DisableClientRoutes = &s.sDisableClientRoutes.Checked
- req.DisableServerRoutes = &s.sDisableServerRoutes.Checked
- req.BlockLanAccess = &s.sBlockLANAccess.Checked
-
- if s.iPreSharedKey.Text != censoredPreSharedKey {
- req.OptionalPreSharedKey = &s.iPreSharedKey.Text
- }
-
- conn, err := s.getSrvClient(failFastTimeout)
- if err != nil {
- log.Errorf("get client: %v", err)
- dialog.ShowError(fmt.Errorf("Failed to connect to the service: %v", err), s.wSettings)
- return
- }
- _, err = conn.SetConfig(s.ctx, &req)
- if err != nil {
- log.Errorf("set config: %v", err)
- dialog.ShowError(fmt.Errorf("Failed to set configuration: %v", err), s.wSettings)
- return
- }
-
- go func() {
- status, err := conn.Status(s.ctx, &proto.StatusRequest{})
- if err != nil {
- log.Errorf("get service status: %v", err)
- dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings)
- return
- }
- if status.Status == string(internal.StatusConnected) {
- // run down & up
- _, err = conn.Down(s.ctx, &proto.DownRequest{})
- if err != nil {
- log.Errorf("down service: %v", err)
- }
-
- _, err = conn.Up(s.ctx, &proto.UpRequest{})
- if err != nil {
- log.Errorf("up service: %v", err)
- dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings)
- return
- }
- }
- }()
- }
- },
- OnCancel: func() {
- s.wSettings.Close()
+func (s *serviceClient) getSSHForm() *widget.Form {
+ return &widget.Form{
+ Items: []*widget.FormItem{
+ {Text: "Enable SSH Root Login", Widget: s.sEnableSSHRoot},
+ {Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP},
+ {Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward},
+ {Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward},
+ {Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth},
+ {Text: "JWT Cache TTL (seconds, 0=disabled)", Widget: s.iSSHJWTCacheTTL},
},
}
}
-func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
+func (s *serviceClient) hasSSHChanges() bool {
+ currentSSHJWTCacheTTL := s.sshJWTCacheTTL
+ if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" {
+ val, err := strconv.Atoi(text)
+ if err != nil {
+ return true
+ }
+ currentSSHJWTCacheTTL = val
+ }
+
+ return s.enableSSHRoot != s.sEnableSSHRoot.Checked ||
+ s.enableSSHSFTP != s.sEnableSSHSFTP.Checked ||
+ s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked ||
+ s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked ||
+ s.disableSSHAuth != s.sDisableSSHAuth.Checked ||
+ s.sshJWTCacheTTL != currentSSHJWTCacheTTL
+}
+
+func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
- log.Errorf("get client: %v", err)
- return nil, err
+ return nil, fmt.Errorf("get daemon client: %w", err)
}
activeProf, err := s.profileManager.GetActiveProfile()
if err != nil {
- log.Errorf("get active profile: %v", err)
- return nil, err
+ return nil, fmt.Errorf("get active profile: %w", err)
}
currUser, err := user.Current()
@@ -611,84 +778,82 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
return nil, fmt.Errorf("get current user: %w", err)
}
- loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
+ loginReq := &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name,
Username: &currUser.Username,
- })
+ }
+
+ profileState, err := s.profileManager.GetProfileState(activeProf.Name)
if err != nil {
- log.Errorf("login to management URL with: %v", err)
- return nil, err
+ log.Debugf("failed to get profile state for login hint: %v", err)
+ } else if profileState.Email != "" {
+ loginReq.Hint = &profileState.Email
+ }
+
+ loginResp, err := conn.Login(ctx, loginReq)
+ if err != nil {
+ return nil, fmt.Errorf("login to management: %w", err)
}
if loginResp.NeedsSSOLogin && openURL {
- err = s.handleSSOLogin(loginResp, conn)
- if err != nil {
- log.Errorf("handle SSO login failed: %v", err)
- return nil, err
+ if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil {
+ return nil, fmt.Errorf("SSO login: %w", err)
}
}
return loginResp, nil
}
-func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
- err := open.Run(loginResp.VerificationURIComplete)
- if err != nil {
- log.Errorf("opening the verification uri in the browser failed: %v", err)
- return err
+func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
+ if err := openURL(loginResp.VerificationURIComplete); err != nil {
+ return fmt.Errorf("open browser: %w", err)
}
- resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
+ resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil {
- log.Errorf("waiting sso login failed with: %v", err)
- return err
+ return fmt.Errorf("wait for SSO login: %w", err)
}
if resp.Email != "" {
- err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
+ if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
Email: resp.Email,
- })
- if err != nil {
- log.Warnf("failed to set profile state: %v", err)
+ }); err != nil {
+ log.Debugf("failed to set profile state: %v", err)
} else {
s.mProfile.refresh()
}
-
}
return nil
}
-func (s *serviceClient) menuUpClick() error {
+func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
systray.SetTemplateIcon(iconErrorMacOS, s.icError)
- log.Errorf("get client: %v", err)
- return err
+ return fmt.Errorf("get daemon client: %w", err)
}
- _, err = s.login(true)
+ _, err = s.login(ctx, true)
if err != nil {
- log.Errorf("login failed with: %v", err)
- return err
+ return fmt.Errorf("login: %w", err)
}
- status, err := conn.Status(s.ctx, &proto.StatusRequest{})
+ status, err := conn.Status(ctx, &proto.StatusRequest{})
if err != nil {
- log.Errorf("get service status: %v", err)
- return err
+ return fmt.Errorf("get status: %w", err)
}
if status.Status == string(internal.StatusConnected) {
- log.Warnf("already connected")
return nil
}
- if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
- log.Errorf("up service: %v", err)
- return err
+ if _, err := s.conn.Up(s.ctx, &proto.UpRequest{
+ AutoUpdate: protobuf.Bool(wannaAutoUpdate),
+ }); err != nil {
+ return fmt.Errorf("start connection: %w", err)
}
return nil
@@ -698,24 +863,20 @@ func (s *serviceClient) menuDownClick() error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
- log.Errorf("get client: %v", err)
- return err
+ return fmt.Errorf("get daemon client: %w", err)
}
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
- log.Errorf("get service status: %v", err)
- return err
+ return fmt.Errorf("get status: %w", err)
}
if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) {
- log.Warnf("already down")
return nil
}
- if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
- log.Errorf("down service: %v", err)
- return err
+ if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
+ return fmt.Errorf("stop connection: %w", err)
}
return nil
@@ -748,7 +909,7 @@ func (s *serviceClient) updateStatus() error {
var systrayIconState bool
switch {
- case status.Status == string(internal.StatusConnected):
+ case status.Status == string(internal.StatusConnected) && !s.mUp.Disabled():
s.connected = true
s.sendNotification = true
if s.isUpdateIconActive {
@@ -762,6 +923,7 @@ func (s *serviceClient) updateStatus() error {
s.mUp.Disable()
s.mDown.Enable()
s.mNetworks.Enable()
+ s.mExitNode.Enable()
go s.updateExitNodes()
systrayIconState = true
case status.Status == string(internal.StatusConnecting):
@@ -851,6 +1013,7 @@ func (s *serviceClient) onTrayReady() {
newProfileMenuArgs := &newProfileMenuArgs{
ctx: s.ctx,
+ serviceClient: s,
profileManager: s.profileManager,
eventHandler: s.eventHandler,
profileMenuItem: profileMenuItem,
@@ -951,9 +1114,32 @@ func (s *serviceClient) onTrayReady() {
s.updateExitNodes()
}
})
+ s.eventManager.AddHandler(func(event *proto.SystemEvent) {
+ // todo use new Category
+ if windowAction, ok := event.Metadata["progress_window"]; ok {
+ targetVersion, ok := event.Metadata["version"]
+ if !ok {
+ targetVersion = "unknown"
+ }
+ log.Debugf("window action: %v", windowAction)
+ if windowAction == "show" {
+ if s.updateContextCancel != nil {
+ s.updateContextCancel()
+ s.updateContextCancel = nil
+ }
+
+ subCtx, cancel := context.WithCancel(s.ctx)
+ go s.eventHandler.runSelfCommand(subCtx, "update", "--update-version", targetVersion)
+ s.updateContextCancel = cancel
+ }
+ }
+ })
go s.eventManager.Start(s.ctx)
go s.eventHandler.listen(s.ctx)
+
+ // Start sleep detection listener
+ go s.startSleepListener()
}
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
@@ -990,6 +1176,8 @@ func (s *serviceClient) onTrayExit() {
// getSrvClient connection to the service.
func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonServiceClient, error) {
+ s.connLock.Lock()
+ defer s.connLock.Unlock()
if s.conn != nil {
return s.conn, nil
}
@@ -1012,6 +1200,62 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
return s.conn, nil
}
+// startSleepListener initializes the sleep detection service and listens for sleep events
+func (s *serviceClient) startSleepListener() {
+ sleepService, err := sleep.New()
+ if err != nil {
+ log.Warnf("%v", err)
+ return
+ }
+
+ if err := sleepService.Register(s.handleSleepEvents); err != nil {
+ log.Errorf("failed to start sleep detection: %v", err)
+ return
+ }
+
+ log.Info("sleep detection service initialized")
+
+ // Cleanup on context cancellation
+ go func() {
+ <-s.ctx.Done()
+ log.Info("stopping sleep event listener")
+ if err := sleepService.Deregister(); err != nil {
+ log.Errorf("failed to deregister sleep detection: %v", err)
+ }
+ }()
+}
+
+// handleSleepEvents sends a sleep notification to the daemon via gRPC
+func (s *serviceClient) handleSleepEvents(event sleep.EventType) {
+ conn, err := s.getSrvClient(0)
+ if err != nil {
+ log.Errorf("failed to get daemon client for sleep notification: %v", err)
+ return
+ }
+
+ req := &proto.OSLifecycleRequest{}
+
+ switch event {
+ case sleep.EventTypeWakeUp:
+ log.Infof("handle wakeup event: %v", event)
+ req.Type = proto.OSLifecycleRequest_WAKEUP
+ case sleep.EventTypeSleep:
+ log.Infof("handle sleep event: %v", event)
+ req.Type = proto.OSLifecycleRequest_SLEEP
+ default:
+ log.Infof("unknown event: %v", event)
+ return
+ }
+
+ _, err = conn.NotifyOSLifecycle(s.ctx, req)
+ if err != nil {
+ log.Errorf("failed to notify daemon about os lifecycle notification: %v", err)
+ return
+ }
+
+ log.Info("successfully notified daemon about os lifecycle")
+}
+
// setSettingsEnabled enables or disables the settings menu based on the provided state
func (s *serviceClient) setSettingsEnabled(enabled bool) {
if s.mSettings != nil {
@@ -1033,19 +1277,22 @@ func (s *serviceClient) checkAndUpdateFeatures() {
return
}
+ s.updateIndicationLock.Lock()
+ defer s.updateIndicationLock.Unlock()
+
// Update settings menu based on current features
- if features != nil && features.DisableUpdateSettings {
- s.setSettingsEnabled(false)
- } else {
- s.setSettingsEnabled(true)
+ settingsEnabled := features == nil || !features.DisableUpdateSettings
+ if s.settingsEnabled != settingsEnabled {
+ s.settingsEnabled = settingsEnabled
+ s.setSettingsEnabled(settingsEnabled)
}
// Update profile menu based on current features
if s.mProfile != nil {
- if features != nil && features.DisableProfiles {
- s.mProfile.setEnabled(false)
- } else {
- s.mProfile.setEnabled(true)
+ profilesEnabled := features == nil || !features.DisableProfiles
+ if s.profilesEnabled != profilesEnabled {
+ s.profilesEnabled = profilesEnabled
+ s.mProfile.setEnabled(profilesEnabled)
}
}
}
@@ -1121,6 +1368,25 @@ func (s *serviceClient) getSrvConfig() {
s.disableServerRoutes = cfg.DisableServerRoutes
s.blockLANAccess = cfg.BlockLANAccess
+ if cfg.EnableSSHRoot != nil {
+ s.enableSSHRoot = *cfg.EnableSSHRoot
+ }
+ if cfg.EnableSSHSFTP != nil {
+ s.enableSSHSFTP = *cfg.EnableSSHSFTP
+ }
+ if cfg.EnableSSHLocalPortForwarding != nil {
+ s.enableSSHLocalPortForward = *cfg.EnableSSHLocalPortForwarding
+ }
+ if cfg.EnableSSHRemotePortForwarding != nil {
+ s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding
+ }
+ if cfg.DisableSSHAuth != nil {
+ s.disableSSHAuth = *cfg.DisableSSHAuth
+ }
+ if cfg.SSHJWTCacheTTL != nil {
+ s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL
+ }
+
if s.showAdvancedSettings {
s.iMngURL.SetText(s.managementURL)
s.iPreSharedKey.SetText(cfg.PreSharedKey)
@@ -1141,6 +1407,24 @@ func (s *serviceClient) getSrvConfig() {
s.sDisableClientRoutes.SetChecked(cfg.DisableClientRoutes)
s.sDisableServerRoutes.SetChecked(cfg.DisableServerRoutes)
s.sBlockLANAccess.SetChecked(cfg.BlockLANAccess)
+ if cfg.EnableSSHRoot != nil {
+ s.sEnableSSHRoot.SetChecked(*cfg.EnableSSHRoot)
+ }
+ if cfg.EnableSSHSFTP != nil {
+ s.sEnableSSHSFTP.SetChecked(*cfg.EnableSSHSFTP)
+ }
+ if cfg.EnableSSHLocalPortForwarding != nil {
+ s.sEnableSSHLocalPortForward.SetChecked(*cfg.EnableSSHLocalPortForwarding)
+ }
+ if cfg.EnableSSHRemotePortForwarding != nil {
+ s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding)
+ }
+ if cfg.DisableSSHAuth != nil {
+ s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth)
+ }
+ if cfg.SSHJWTCacheTTL != nil {
+ s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL))
+ }
}
if s.mNotifications == nil {
@@ -1211,6 +1495,15 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
config.DisableServerRoutes = cfg.DisableServerRoutes
config.BlockLANAccess = cfg.BlockLanAccess
+ config.EnableSSHRoot = &cfg.EnableSSHRoot
+ config.EnableSSHSFTP = &cfg.EnableSSHSFTP
+ config.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding
+ config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
+ config.DisableSSHAuth = &cfg.DisableSSHAuth
+
+ ttl := int(cfg.SshJWTCacheTTL)
+ config.SSHJWTCacheTTL = &ttl
+
return &config
}
@@ -1382,7 +1675,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return
}
- resp, err := s.login(false)
+ resp, err := s.login(ctx, false)
if err != nil {
log.Errorf("failed to fetch login URL: %v", err)
return
@@ -1402,7 +1695,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return
}
- _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
+ _, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
if err != nil {
log.Errorf("Waiting sso login failed with: %v", err)
label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.")
@@ -1410,7 +1703,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
}
label.SetText("Re-authentication successful.\nReconnecting")
- status, err := conn.Status(s.ctx, &proto.StatusRequest{})
+ status, err := conn.Status(ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
return
@@ -1423,7 +1716,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return
}
- _, err = conn.Up(s.ctx, &proto.UpRequest{})
+ _, err = conn.Up(ctx, &proto.UpRequest{})
if err != nil {
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
log.Errorf("Reconnecting failed with: %v", err)
@@ -1487,6 +1780,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
}
func openURL(url string) error {
+ if browser := os.Getenv("BROWSER"); browser != "" {
+ return exec.Command(browser, url).Start()
+ }
+
var err error
switch runtime.GOOS {
case "windows":
diff --git a/client/ui/debug.go b/client/ui/debug.go
index 76afc7753..51fa28575 100644
--- a/client/ui/debug.go
+++ b/client/ui/debug.go
@@ -18,6 +18,7 @@ import (
"github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types"
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
return "", err
}
+ pm := profilemanager.NewProfileManager()
+ var profName string
+ if activeProf, err := pm.GetActiveProfile(); err == nil {
+ profName = activeProf.Name
+ }
+
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get post-up status: %v", err)
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string
if postUpStatus != nil {
- overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "")
+ overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string
if preDownStatus != nil {
- overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "")
+ overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -493,7 +500,7 @@ func (s *serviceClient) createDebugBundleFromCollection(
if uploadFailureReason != "" {
showUploadFailedDialog(progress.window, localPath, uploadFailureReason)
} else {
- showUploadSuccessDialog(progress.window, localPath, uploadedKey)
+ showUploadSuccessDialog(s.app, progress.window, localPath, uploadedKey)
}
} else {
showBundleCreatedDialog(progress.window, localPath)
@@ -558,7 +565,7 @@ func (s *serviceClient) handleDebugCreation(
if uploadFailureReason != "" {
showUploadFailedDialog(w, localPath, uploadFailureReason)
} else {
- showUploadSuccessDialog(w, localPath, uploadedKey)
+ showUploadSuccessDialog(s.app, w, localPath, uploadedKey)
}
} else {
showBundleCreatedDialog(w, localPath)
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
return nil, fmt.Errorf("get client: %v", err)
}
+ pm := profilemanager.NewProfileManager()
+ var profName string
+ if activeProf, err := pm.GetActiveProfile(); err == nil {
+ profName = activeProf.Name
+ }
+
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("failed to get status for debug bundle: %v", err)
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string
if statusResp != nil {
- overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "")
+ overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
@@ -652,7 +665,7 @@ func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) {
}
// showUploadSuccessDialog displays a dialog when upload succeeds
-func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
+func showUploadSuccessDialog(a fyne.App, w fyne.Window, localPath, uploadedKey string) {
log.Infof("Upload key: %s", uploadedKey)
keyEntry := widget.NewEntry()
keyEntry.SetText(uploadedKey)
@@ -670,7 +683,7 @@ func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
customDialog := dialog.NewCustom("Upload Successful", "OK", content, w)
copyBtn := createButtonWithAction("Copy key", func() {
- w.Clipboard().SetContent(uploadedKey)
+ a.Clipboard().SetContent(uploadedKey)
log.Info("Upload key copied to clipboard")
})
diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go
index e9b7f4f30..9ffacd926 100644
--- a/client/ui/event_handler.go
+++ b/client/ui/event_handler.go
@@ -12,6 +12,8 @@ import (
"fyne.io/fyne/v2"
"fyne.io/systray"
log "github.com/sirupsen/logrus"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
@@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) {
func (h *eventHandler) handleConnectClick() {
h.client.mUp.Disable()
+
+ if h.client.connectCancel != nil {
+ h.client.connectCancel()
+ }
+
+ connectCtx, connectCancel := context.WithCancel(h.client.ctx)
+ h.client.connectCancel = connectCancel
+
go func() {
- defer h.client.mUp.Enable()
- if err := h.client.menuUpClick(); err != nil {
- h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
+ defer connectCancel()
+
+ if err := h.client.menuUpClick(connectCtx, true); err != nil {
+ st, ok := status.FromError(err)
+ if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
+ log.Debugf("connect operation cancelled by user")
+ } else {
+ h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
+ log.Errorf("connect failed: %v", err)
+ }
+ }
+
+ if err := h.client.updateStatus(); err != nil {
+ log.Debugf("failed to update status after connect: %v", err)
}
}()
}
func (h *eventHandler) handleDisconnectClick() {
h.client.mDown.Disable()
+
+ if h.client.connectCancel != nil {
+ log.Debugf("cancelling ongoing connect operation")
+ h.client.connectCancel()
+ h.client.connectCancel = nil
+ }
+
go func() {
- defer h.client.mDown.Enable()
if err := h.client.menuDownClick(); err != nil {
- h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon"))
+ st, ok := status.FromError(err)
+ if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
+ h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
+ log.Errorf("disconnect failed: %v", err)
+ } else {
+ log.Debugf("disconnect cancelled or already disconnecting")
+ }
+ }
+
+ if err := h.client.updateStatus(); err != nil {
+ log.Debugf("failed to update status after disconnect: %v", err)
}
}()
}
@@ -148,7 +185,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() {
go func() {
defer h.client.mAdvancedSettings.Enable()
defer h.client.getSrvConfig()
- h.runSelfCommand(h.client.ctx, "settings", "true")
+ h.runSelfCommand(h.client.ctx, "settings")
}()
}
@@ -156,7 +193,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() {
h.client.mCreateDebugBundle.Disable()
go func() {
defer h.client.mCreateDebugBundle.Enable()
- h.runSelfCommand(h.client.ctx, "debug", "true")
+ h.runSelfCommand(h.client.ctx, "debug")
}()
}
@@ -180,7 +217,7 @@ func (h *eventHandler) handleNetworksClick() {
h.client.mNetworks.Disable()
go func() {
defer h.client.mNetworks.Enable()
- h.runSelfCommand(h.client.ctx, "networks", "true")
+ h.runSelfCommand(h.client.ctx, "networks")
}()
}
@@ -200,17 +237,21 @@ func (h *eventHandler) updateConfigWithErr() error {
return nil
}
-func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) {
+func (h *eventHandler) runSelfCommand(ctx context.Context, command string, args ...string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("error getting executable path: %v", err)
return
}
- cmd := exec.CommandContext(ctx, proc,
- fmt.Sprintf("--%s=%s", command, arg),
+ // Build the full command arguments
+ cmdArgs := []string{
+ fmt.Sprintf("--%s=true", command),
fmt.Sprintf("--daemon-addr=%s", h.client.addr),
- )
+ }
+ cmdArgs = append(cmdArgs, args...)
+
+ cmd := exec.CommandContext(ctx, proc, cmdArgs...)
if out := h.client.attachOutput(cmd); out != nil {
defer func() {
@@ -220,17 +261,17 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string)
}()
}
- log.Printf("running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, h.client.addr)
+ log.Printf("running command: %s", cmd.String())
if err := cmd.Run(); err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
- log.Printf("command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode())
+ log.Printf("command '%s' failed with exit code %d", cmd.String(), exitErr.ExitCode())
}
return
}
- log.Printf("command '%s %s' completed successfully", command, arg)
+ log.Printf("command '%s' completed successfully", cmd.String())
}
func (h *eventHandler) logout(ctx context.Context) error {
@@ -245,6 +286,6 @@ func (h *eventHandler) logout(ctx context.Context) error {
}
h.client.getSrvConfig()
-
+
return nil
}
diff --git a/client/ui/icons.go b/client/ui/icons.go
index e88fb9378..874f24fdd 100644
--- a/client/ui/icons.go
+++ b/client/ui/icons.go
@@ -9,6 +9,9 @@ import (
//go:embed assets/netbird.png
var iconAbout []byte
+//go:embed assets/netbird-disconnected.png
+var iconAboutDisconnected []byte
+
//go:embed assets/netbird-systemtray-connected.png
var iconConnected []byte
diff --git a/client/ui/icons_windows.go b/client/ui/icons_windows.go
index 2107d3852..bd57b2690 100644
--- a/client/ui/icons_windows.go
+++ b/client/ui/icons_windows.go
@@ -7,6 +7,9 @@ import (
//go:embed assets/netbird.ico
var iconAbout []byte
+//go:embed assets/netbird-disconnected.ico
+var iconAboutDisconnected []byte
+
//go:embed assets/netbird-systemtray-connected.ico
var iconConnected []byte
diff --git a/client/ui/process/process.go b/client/ui/process/process.go
index d0ef54896..28276f416 100644
--- a/client/ui/process/process.go
+++ b/client/ui/process/process.go
@@ -28,7 +28,8 @@ func IsAnotherProcessRunning() (int32, bool, error) {
continue
}
- if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) {
+ runningProcessName := strings.ToLower(filepath.Base(runningProcessPath))
+ if runningProcessName == processName && isProcessOwnedByCurrentUser(p) {
return p.Pid, true, nil
}
}
diff --git a/client/ui/profile.go b/client/ui/profile.go
index 075223795..a38d8918a 100644
--- a/client/ui/profile.go
+++ b/client/ui/profile.go
@@ -387,6 +387,7 @@ type subItem struct {
type profileMenu struct {
mu sync.Mutex
ctx context.Context
+ serviceClient *serviceClient
profileManager *profilemanager.ProfileManager
eventHandler *eventHandler
profileMenuItem *systray.MenuItem
@@ -396,7 +397,7 @@ type profileMenu struct {
logoutSubItem *subItem
profilesState []Profile
downClickCallback func() error
- upClickCallback func() error
+ upClickCallback func(context.Context, bool) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func()
app fyne.App
@@ -404,12 +405,13 @@ type profileMenu struct {
type newProfileMenuArgs struct {
ctx context.Context
+ serviceClient *serviceClient
profileManager *profilemanager.ProfileManager
eventHandler *eventHandler
profileMenuItem *systray.MenuItem
emailMenuItem *systray.MenuItem
downClickCallback func() error
- upClickCallback func() error
+ upClickCallback func(context.Context, bool) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func()
app fyne.App
@@ -418,6 +420,7 @@ type newProfileMenuArgs struct {
func newProfileMenu(args newProfileMenuArgs) *profileMenu {
p := profileMenu{
ctx: args.ctx,
+ serviceClient: args.serviceClient,
profileManager: args.profileManager,
eventHandler: args.eventHandler,
profileMenuItem: args.profileMenuItem,
@@ -569,10 +572,19 @@ func (p *profileMenu) refresh() {
}
}
- if err := p.upClickCallback(); err != nil {
+ if p.serviceClient.connectCancel != nil {
+ p.serviceClient.connectCancel()
+ }
+
+ connectCtx, connectCancel := context.WithCancel(p.ctx)
+ p.serviceClient.connectCancel = connectCancel
+
+ if err := p.upClickCallback(connectCtx, false); err != nil {
log.Errorf("failed to handle up click after switching profile: %v", err)
}
+ connectCancel()
+
p.refresh()
p.loadSettingsCallback()
}
diff --git a/client/ui/quickactions.go b/client/ui/quickactions.go
new file mode 100644
index 000000000..76440d684
--- /dev/null
+++ b/client/ui/quickactions.go
@@ -0,0 +1,349 @@
+//go:build !(linux && 386)
+
+//go:generate fyne bundle -o quickactions_assets.go assets/connected.png
+//go:generate fyne bundle -o quickactions_assets.go -append assets/disconnected.png
+package main
+
+import (
+ "context"
+ _ "embed"
+ "fmt"
+ "runtime"
+ "sync/atomic"
+ "time"
+
+ "fyne.io/fyne/v2"
+ "fyne.io/fyne/v2/canvas"
+ "fyne.io/fyne/v2/container"
+ "fyne.io/fyne/v2/layout"
+ "fyne.io/fyne/v2/widget"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/proto"
+)
+
+type quickActionsUiState struct {
+ connectionStatus string
+ isToggleButtonEnabled bool
+ isConnectionChanged bool
+ toggleAction func()
+}
+
+func newQuickActionsUiState() quickActionsUiState {
+ return quickActionsUiState{
+ connectionStatus: string(internal.StatusIdle),
+ isToggleButtonEnabled: false,
+ isConnectionChanged: false,
+ }
+}
+
+type clientConnectionStatusProvider interface {
+ connectionStatus(ctx context.Context) (string, error)
+}
+
+type daemonClientConnectionStatusProvider struct {
+ client proto.DaemonServiceClient
+}
+
+func (d daemonClientConnectionStatusProvider) connectionStatus(ctx context.Context) (string, error) {
+ childCtx, cancel := context.WithTimeout(ctx, 400*time.Millisecond)
+ defer cancel()
+ status, err := d.client.Status(childCtx, &proto.StatusRequest{})
+ if err != nil {
+ return "", err
+ }
+
+ return status.Status, nil
+}
+
+type clientCommand interface {
+ execute() error
+}
+
+type connectCommand struct {
+ connectClient func() error
+}
+
+func (c connectCommand) execute() error {
+ return c.connectClient()
+}
+
+type disconnectCommand struct {
+ disconnectClient func() error
+}
+
+func (c disconnectCommand) execute() error {
+ return c.disconnectClient()
+}
+
+type quickActionsViewModel struct {
+ provider clientConnectionStatusProvider
+ connect clientCommand
+ disconnect clientCommand
+ uiChan chan quickActionsUiState
+ isWatchingConnectionStatus atomic.Bool
+}
+
+func newQuickActionsViewModel(ctx context.Context, provider clientConnectionStatusProvider, connect, disconnect clientCommand, uiChan chan quickActionsUiState) {
+ viewModel := quickActionsViewModel{
+ provider: provider,
+ connect: connect,
+ disconnect: disconnect,
+ uiChan: uiChan,
+ }
+
+ viewModel.isWatchingConnectionStatus.Store(true)
+
+ // base UI status
+ uiChan <- newQuickActionsUiState()
+
+ // this retrieves the current connection status
+ // and pushes the UI state that reflects it via uiChan
+ go viewModel.watchConnectionStatus(ctx)
+}
+
+func (q *quickActionsViewModel) updateUiState(ctx context.Context) {
+ uiState := newQuickActionsUiState()
+ connectionStatus, err := q.provider.connectionStatus(ctx)
+
+ if err != nil {
+ log.Errorf("Status: Error - %v", err)
+ q.uiChan <- uiState
+ return
+ }
+
+ if connectionStatus == string(internal.StatusConnected) {
+ uiState.toggleAction = func() {
+ q.executeCommand(q.disconnect)
+ }
+ } else {
+ uiState.toggleAction = func() {
+ q.executeCommand(q.connect)
+ }
+ }
+
+ uiState.isToggleButtonEnabled = true
+ uiState.connectionStatus = connectionStatus
+ q.uiChan <- uiState
+}
+
+func (q *quickActionsViewModel) watchConnectionStatus(ctx context.Context) {
+ ticker := time.NewTicker(1000 * time.Millisecond)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ if q.isWatchingConnectionStatus.Load() {
+ q.updateUiState(ctx)
+ }
+ }
+ }
+}
+
+func (q *quickActionsViewModel) executeCommand(command clientCommand) {
+ uiState := newQuickActionsUiState()
+ // newQuickActionsUiState starts with Idle connection status,
+ // and all that's necessary here is to just disable the toggle button.
+ uiState.connectionStatus = ""
+
+ q.uiChan <- uiState
+
+ q.isWatchingConnectionStatus.Store(false)
+
+ err := command.execute()
+
+ if err != nil {
+ log.Errorf("Status: Error - %v", err)
+ q.isWatchingConnectionStatus.Store(true)
+ } else {
+ uiState = newQuickActionsUiState()
+ uiState.isConnectionChanged = true
+ q.uiChan <- uiState
+ }
+}
+
+func getSystemTrayName() string {
+ os := runtime.GOOS
+ switch os {
+ case "darwin":
+ return "menu bar"
+ default:
+ return "system tray"
+ }
+}
+
+func (s *serviceClient) getNetBirdImage(name string, content []byte) *canvas.Image {
+ imageSize := fyne.NewSize(64, 64)
+
+ resource := fyne.NewStaticResource(name, content)
+ image := canvas.NewImageFromResource(resource)
+ image.FillMode = canvas.ImageFillContain
+ image.SetMinSize(imageSize)
+ image.Resize(imageSize)
+
+ return image
+}
+
+type quickActionsUiComponents struct {
+ content *fyne.Container
+ toggleConnectionButton *widget.Button
+ connectedLabelText, disconnectedLabelText string
+ connectedImage, disconnectedImage *canvas.Image
+ connectedCircleRes, disconnectedCircleRes fyne.Resource
+}
+
+// applyQuickActionsUiState applies a single UI state to the quick actions window.
+// It closes the window and returns true if the connection status has changed,
+// in which case the caller should stop processing further states.
+func (s *serviceClient) applyQuickActionsUiState(
+ uiState quickActionsUiState,
+ components quickActionsUiComponents,
+) bool {
+ if uiState.isConnectionChanged {
+ fyne.DoAndWait(func() {
+ s.wQuickActions.Close()
+ })
+ return true
+ }
+
+ var logo *canvas.Image
+ var buttonText string
+ var buttonIcon fyne.Resource
+
+ if uiState.connectionStatus == string(internal.StatusConnected) {
+ buttonText = components.connectedLabelText
+ buttonIcon = components.connectedCircleRes
+ logo = components.connectedImage
+ } else if uiState.connectionStatus == string(internal.StatusIdle) {
+ buttonText = components.disconnectedLabelText
+ buttonIcon = components.disconnectedCircleRes
+ logo = components.disconnectedImage
+ }
+
+ fyne.DoAndWait(func() {
+ if buttonText != "" {
+ components.toggleConnectionButton.SetText(buttonText)
+ }
+
+ if buttonIcon != nil {
+ components.toggleConnectionButton.SetIcon(buttonIcon)
+ }
+
+ if uiState.isToggleButtonEnabled {
+ components.toggleConnectionButton.Enable()
+ } else {
+ components.toggleConnectionButton.Disable()
+ }
+
+ components.toggleConnectionButton.OnTapped = func() {
+ if uiState.toggleAction != nil {
+ go uiState.toggleAction()
+ }
+ }
+
+ components.toggleConnectionButton.Refresh()
+
+ // the second position in the content's object array is the NetBird logo.
+ if logo != nil {
+ components.content.Objects[1] = logo
+ components.content.Refresh()
+ }
+ })
+
+ return false
+}
+
+// showQuickActionsUI displays a simple window with the NetBird logo and a connection toggle button.
+func (s *serviceClient) showQuickActionsUI() {
+ s.wQuickActions = s.app.NewWindow("NetBird")
+ vmCtx, vmCancel := context.WithCancel(s.ctx)
+ s.wQuickActions.SetOnClosed(vmCancel)
+
+ client, err := s.getSrvClient(defaultFailTimeout)
+
+ connCmd := connectCommand{
+ connectClient: func() error {
+ return s.menuUpClick(s.ctx, false)
+ },
+ }
+
+ disConnCmd := disconnectCommand{
+ disconnectClient: func() error {
+ return s.menuDownClick()
+ },
+ }
+
+ if err != nil {
+ log.Errorf("get service client: %v", err)
+ return
+ }
+
+ uiChan := make(chan quickActionsUiState, 1)
+ newQuickActionsViewModel(vmCtx, daemonClientConnectionStatusProvider{client: client}, connCmd, disConnCmd, uiChan)
+
+ connectedImage := s.getNetBirdImage("netbird.png", iconAbout)
+ disconnectedImage := s.getNetBirdImage("netbird-disconnected.png", iconAboutDisconnected)
+
+ connectedCircle := canvas.NewImageFromResource(resourceConnectedPng)
+ disconnectedCircle := canvas.NewImageFromResource(resourceDisconnectedPng)
+
+ connectedLabelText := "Disconnect"
+ disconnectedLabelText := "Connect"
+
+ toggleConnectionButton := widget.NewButtonWithIcon(disconnectedLabelText, disconnectedCircle.Resource, func() {
+ // This button's tap function will be set when an ui state arrives via the uiChan channel.
+ })
+
+ // Button starts disabled until the first ui state arrives.
+ toggleConnectionButton.Disable()
+
+ hintLabelText := fmt.Sprintf("You can always access NetBird from your %s.", getSystemTrayName())
+ hintLabel := widget.NewLabel(hintLabelText)
+
+ content := container.NewVBox(
+ layout.NewSpacer(),
+ disconnectedImage,
+ layout.NewSpacer(),
+ container.NewCenter(toggleConnectionButton),
+ layout.NewSpacer(),
+ container.NewCenter(hintLabel),
+ )
+
+ // this watches for ui state updates.
+ go func() {
+
+ for {
+ select {
+ case <-vmCtx.Done():
+ return
+ case uiState, ok := <-uiChan:
+ if !ok {
+ return
+ }
+
+ closed := s.applyQuickActionsUiState(
+ uiState,
+ quickActionsUiComponents{
+ content,
+ toggleConnectionButton,
+ connectedLabelText, disconnectedLabelText,
+ connectedImage, disconnectedImage,
+ connectedCircle.Resource, disconnectedCircle.Resource,
+ },
+ )
+ if closed {
+ return
+ }
+ }
+ }
+ }()
+
+ s.wQuickActions.SetContent(content)
+ s.wQuickActions.Resize(fyne.NewSize(400, 200))
+ s.wQuickActions.SetFixedSize(true)
+ s.wQuickActions.Show()
+}
diff --git a/client/ui/quickactions_assets.go b/client/ui/quickactions_assets.go
new file mode 100644
index 000000000..9ff5e85a2
--- /dev/null
+++ b/client/ui/quickactions_assets.go
@@ -0,0 +1,23 @@
+// auto-generated
+// Code generated by '$ fyne bundle'. DO NOT EDIT.
+
+package main
+
+import (
+ _ "embed"
+ "fyne.io/fyne/v2"
+)
+
+//go:embed assets/connected.png
+var resourceConnectedPngData []byte
+var resourceConnectedPng = &fyne.StaticResource{
+ StaticName: "assets/connected.png",
+ StaticContent: resourceConnectedPngData,
+}
+
+//go:embed assets/disconnected.png
+var resourceDisconnectedPngData []byte
+var resourceDisconnectedPng = &fyne.StaticResource{
+ StaticName: "assets/disconnected.png",
+ StaticContent: resourceDisconnectedPngData,
+}
diff --git a/client/ui/signal_unix.go b/client/ui/signal_unix.go
new file mode 100644
index 000000000..99de99f0f
--- /dev/null
+++ b/client/ui/signal_unix.go
@@ -0,0 +1,76 @@
+//go:build !windows && !(linux && 386)
+
+package main
+
+import (
+ "context"
+ "os"
+ "os/exec"
+ "os/signal"
+ "syscall"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// setupSignalHandler sets up a signal handler to listen for SIGUSR1.
+// When received, it opens the quick actions window.
+func (s *serviceClient) setupSignalHandler(ctx context.Context) {
+ sigChan := make(chan os.Signal, 1)
+ signal.Notify(sigChan, syscall.SIGUSR1)
+
+ go func() {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-sigChan:
+ log.Info("received SIGUSR1 signal, opening quick actions window")
+ s.openQuickActions()
+ }
+ }
+ }()
+}
+
+// openQuickActions opens the quick actions window by spawning a new process.
+func (s *serviceClient) openQuickActions() {
+ proc, err := os.Executable()
+ if err != nil {
+ log.Errorf("get executable path: %v", err)
+ return
+ }
+
+ cmd := exec.CommandContext(s.ctx, proc,
+ "--quick-actions=true",
+ "--daemon-addr="+s.addr,
+ )
+
+ if out := s.attachOutput(cmd); out != nil {
+ defer func() {
+ if err := out.Close(); err != nil {
+ log.Errorf("close log file %s: %v", s.logFile, err)
+ }
+ }()
+ }
+
+ log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr)
+
+ if err := cmd.Start(); err != nil {
+ log.Errorf("start quick actions window: %v", err)
+ return
+ }
+
+ go func() {
+ if err := cmd.Wait(); err != nil {
+ log.Debugf("quick actions window exited: %v", err)
+ }
+ }()
+}
+
+// sendShowWindowSignal sends SIGUSR1 to the specified PID.
+func sendShowWindowSignal(pid int32) error {
+ process, err := os.FindProcess(int(pid))
+ if err != nil {
+ return err
+ }
+ return process.Signal(syscall.SIGUSR1)
+}
diff --git a/client/ui/signal_windows.go b/client/ui/signal_windows.go
new file mode 100644
index 000000000..ca98be526
--- /dev/null
+++ b/client/ui/signal_windows.go
@@ -0,0 +1,171 @@
+//go:build windows
+
+package main
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+)
+
+const (
+ quickActionsTriggerEventName = `Global\NetBirdQuickActionsTriggerEvent`
+ waitTimeout = 5 * time.Second
+ // SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
+ desiredAccesses = windows.SYNCHRONIZE | windows.EVENT_MODIFY_STATE
+)
+
+func getEventNameUint16Pointer() (*uint16, error) {
+ eventNamePtr, err := windows.UTF16PtrFromString(quickActionsTriggerEventName)
+ if err != nil {
+ log.Errorf("Failed to convert event name '%s' to UTF16: %v", quickActionsTriggerEventName, err)
+ return nil, err
+ }
+
+ return eventNamePtr, nil
+}
+
+// setupSignalHandler sets up signal handling for Windows.
+// Windows doesn't support SIGUSR1, so this uses a similar approach using windows.Events.
+func (s *serviceClient) setupSignalHandler(ctx context.Context) {
+ eventNamePtr, err := getEventNameUint16Pointer()
+ if err != nil {
+ return
+ }
+
+ eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
+
+ if err != nil {
+ if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
+ log.Warnf("Quick actions trigger event '%s' already exists. Attempting to open.", quickActionsTriggerEventName)
+ eventHandle, err = windows.OpenEvent(desiredAccesses, false, eventNamePtr)
+ if err != nil {
+ log.Errorf("Failed to open existing quick actions trigger event '%s': %v", quickActionsTriggerEventName, err)
+ return
+ }
+ log.Infof("Successfully opened existing quick actions trigger event '%s'.", quickActionsTriggerEventName)
+ } else {
+ log.Errorf("Failed to create quick actions trigger event '%s': %v", quickActionsTriggerEventName, err)
+ return
+ }
+ }
+
+ if eventHandle == windows.InvalidHandle {
+ log.Errorf("Obtained an invalid handle for quick actions trigger event '%s'", quickActionsTriggerEventName)
+ return
+ }
+
+ log.Infof("Quick actions handler waiting for signal on event: %s", quickActionsTriggerEventName)
+
+ go s.waitForEvent(ctx, eventHandle)
+}
+
+func (s *serviceClient) waitForEvent(ctx context.Context, eventHandle windows.Handle) {
+ defer func() {
+ if err := windows.CloseHandle(eventHandle); err != nil {
+ log.Errorf("Failed to close quick actions event handle '%s': %v", quickActionsTriggerEventName, err)
+ }
+ }()
+
+ for {
+ if ctx.Err() != nil {
+ return
+ }
+
+ status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
+
+ switch status {
+ case windows.WAIT_OBJECT_0:
+ log.Info("Received signal on quick actions event. Opening quick actions window.")
+
+ // reset the event so it can be triggered again later (manual reset == 1)
+ if err := windows.ResetEvent(eventHandle); err != nil {
+ log.Errorf("Failed to reset quick actions event '%s': %v", quickActionsTriggerEventName, err)
+ }
+
+ s.openQuickActions()
+ case uint32(windows.WAIT_TIMEOUT):
+
+ default:
+ if isDone := logUnexpectedStatus(ctx, status, err); isDone {
+ return
+ }
+ }
+ }
+}
+
+func logUnexpectedStatus(ctx context.Context, status uint32, err error) bool {
+ log.Errorf("Unexpected status %d from WaitForSingleObject for quick actions event '%s': %v",
+ status, quickActionsTriggerEventName, err)
+ select {
+ case <-time.After(5 * time.Second):
+ return false
+ case <-ctx.Done():
+ return true
+ }
+}
+
+// openQuickActions opens the quick actions window by spawning a new process.
+func (s *serviceClient) openQuickActions() {
+ proc, err := os.Executable()
+ if err != nil {
+ log.Errorf("get executable path: %v", err)
+ return
+ }
+
+ cmd := exec.CommandContext(s.ctx, proc,
+ "--quick-actions=true",
+ "--daemon-addr="+s.addr,
+ )
+
+ if out := s.attachOutput(cmd); out != nil {
+ defer func() {
+ if err := out.Close(); err != nil {
+ log.Errorf("close log file %s: %v", s.logFile, err)
+ }
+ }()
+ }
+
+ log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr)
+
+ if err := cmd.Start(); err != nil {
+ log.Errorf("error starting quick actions window: %v", err)
+ return
+ }
+
+ go func() {
+ if err := cmd.Wait(); err != nil {
+ log.Debugf("quick actions window exited: %v", err)
+ }
+ }()
+}
+
+func sendShowWindowSignal(pid int32) error {
+ _, err := os.FindProcess(int(pid))
+ if err != nil {
+ return err
+ }
+
+ eventNamePtr, err := getEventNameUint16Pointer()
+ if err != nil {
+ return err
+ }
+
+ eventHandle, err := windows.OpenEvent(desiredAccesses, false, eventNamePtr)
+ if err != nil {
+ return err
+ }
+
+ err = windows.SetEvent(eventHandle)
+ if err != nil {
+ return fmt.Errorf("Error setting event: %w", err)
+ }
+
+ return nil
+}
diff --git a/client/ui/update.go b/client/ui/update.go
new file mode 100644
index 000000000..25c317bdf
--- /dev/null
+++ b/client/ui/update.go
@@ -0,0 +1,140 @@
+//go:build !(linux && 386)
+
+package main
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "fyne.io/fyne/v2/container"
+ "fyne.io/fyne/v2/widget"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/proto"
+)
+
+func (s *serviceClient) showUpdateProgress(ctx context.Context, version string) {
+ log.Infof("show installer progress window: %s", version)
+ s.wUpdateProgress = s.app.NewWindow("Automatically updating client")
+
+ statusLabel := widget.NewLabel("Updating...")
+ infoLabel := widget.NewLabel(fmt.Sprintf("Your client version is older than the auto-update version set in Management.\nUpdating client to: %s.", version))
+ content := container.NewVBox(infoLabel, statusLabel)
+ s.wUpdateProgress.SetContent(content)
+ s.wUpdateProgress.CenterOnScreen()
+ s.wUpdateProgress.SetFixedSize(true)
+ s.wUpdateProgress.SetCloseIntercept(func() {
+ // this is empty to lock window until result known
+ })
+ s.wUpdateProgress.RequestFocus()
+ s.wUpdateProgress.Show()
+
+ updateWindowCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
+
+ // Initialize dot updater
+ updateText := dotUpdater()
+
+ // Channel to receive the result from RPC call
+ resultErrCh := make(chan error, 1)
+ resultOkCh := make(chan struct{}, 1)
+
+ // Start RPC call in background
+ go func() {
+ conn, err := s.getSrvClient(defaultFailTimeout)
+ if err != nil {
+ log.Infof("backend not reachable, upgrade in progress: %v", err)
+ close(resultOkCh)
+ return
+ }
+
+ resp, err := conn.GetInstallerResult(updateWindowCtx, &proto.InstallerResultRequest{})
+ if err != nil {
+ log.Infof("backend stopped responding, upgrade in progress: %v", err)
+ close(resultOkCh)
+ return
+ }
+
+ if !resp.Success {
+ resultErrCh <- mapInstallError(resp.ErrorMsg)
+ return
+ }
+
+ // Success
+ close(resultOkCh)
+ }()
+
+ // Update UI with dots and wait for result
+ go func() {
+ ticker := time.NewTicker(time.Second)
+ defer ticker.Stop()
+ defer cancel()
+
+ // allow closing update window after 10 sec
+ timerResetCloseInterceptor := time.NewTimer(10 * time.Second)
+ defer timerResetCloseInterceptor.Stop()
+
+ for {
+ select {
+ case <-updateWindowCtx.Done():
+ s.showInstallerResult(statusLabel, updateWindowCtx.Err())
+ return
+ case err := <-resultErrCh:
+ s.showInstallerResult(statusLabel, err)
+ return
+ case <-resultOkCh:
+ log.Info("backend exited, upgrade in progress, closing all UI")
+ killParentUIProcess()
+ s.app.Quit()
+ return
+ case <-ticker.C:
+ statusLabel.SetText(updateText())
+ case <-timerResetCloseInterceptor.C:
+ s.wUpdateProgress.SetCloseIntercept(nil)
+ }
+ }
+ }()
+}
+
+func (s *serviceClient) showInstallerResult(statusLabel *widget.Label, err error) {
+ s.wUpdateProgress.SetCloseIntercept(nil)
+ switch {
+ case errors.Is(err, context.DeadlineExceeded):
+ log.Warn("update watcher timed out")
+ statusLabel.SetText("Update timed out. Please try again.")
+ case errors.Is(err, context.Canceled):
+ log.Info("update watcher canceled")
+ statusLabel.SetText("Update canceled.")
+ case err != nil:
+ log.Errorf("update failed: %v", err)
+ statusLabel.SetText("Update failed: " + err.Error())
+ default:
+ s.wUpdateProgress.Close()
+ }
+}
+
+// dotUpdater returns a closure that cycles through dots for a loading animation.
+func dotUpdater() func() string {
+ dotCount := 0
+ return func() string {
+ dotCount = (dotCount + 1) % 4
+ return fmt.Sprintf("%s%s", "Updating", strings.Repeat(".", dotCount))
+ }
+}
+
+func mapInstallError(msg string) error {
+ msg = strings.ToLower(strings.TrimSpace(msg))
+
+ switch {
+ case strings.Contains(msg, "deadline exceeded"), strings.Contains(msg, "timeout"):
+ return context.DeadlineExceeded
+ case strings.Contains(msg, "canceled"), strings.Contains(msg, "cancelled"):
+ return context.Canceled
+ case msg == "":
+ return errors.New("unknown update error")
+ default:
+ return errors.New(msg)
+ }
+}
diff --git a/client/ui/update_notwindows.go b/client/ui/update_notwindows.go
new file mode 100644
index 000000000..5766f18f7
--- /dev/null
+++ b/client/ui/update_notwindows.go
@@ -0,0 +1,7 @@
+//go:build !windows && !(linux && 386)
+
+package main
+
+func killParentUIProcess() {
+ // No-op on non-Windows platforms
+}
diff --git a/client/ui/update_windows.go b/client/ui/update_windows.go
new file mode 100644
index 000000000..1b03936f9
--- /dev/null
+++ b/client/ui/update_windows.go
@@ -0,0 +1,44 @@
+//go:build windows
+
+package main
+
+import (
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+
+ nbprocess "github.com/netbirdio/netbird/client/ui/process"
+)
+
+// killParentUIProcess finds and kills the parent systray UI process on Windows.
+// This is a workaround in case the MSI installer fails to properly terminate the UI process.
+// The installer should handle this via util:CloseApplication with TerminateProcess, but this
+// provides an additional safety mechanism to ensure the UI is closed before the upgrade proceeds.
+func killParentUIProcess() {
+ pid, running, err := nbprocess.IsAnotherProcessRunning()
+ if err != nil {
+ log.Warnf("failed to check for parent UI process: %v", err)
+ return
+ }
+
+ if !running {
+ log.Debug("no parent UI process found to kill")
+ return
+ }
+
+ log.Infof("killing parent UI process (PID: %d)", pid)
+
+ // Open the process with terminate rights
+ handle, err := windows.OpenProcess(windows.PROCESS_TERMINATE, false, uint32(pid))
+ if err != nil {
+ log.Warnf("failed to open parent process %d: %v", pid, err)
+ return
+ }
+ defer func() {
+ _ = windows.CloseHandle(handle)
+ }()
+
+ // Terminate the process with exit code 0
+ if err := windows.TerminateProcess(handle, 0); err != nil {
+ log.Warnf("failed to terminate parent process %d: %v", pid, err)
+ }
+}
diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go
index d542e2739..238e272fa 100644
--- a/client/wasm/cmd/main.go
+++ b/client/wasm/cmd/main.go
@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"
netbird "github.com/netbirdio/netbird/client/embed"
+ sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
@@ -18,9 +19,10 @@ import (
)
const (
- clientStartTimeout = 30 * time.Second
- clientStopTimeout = 10 * time.Second
- defaultLogLevel = "warn"
+ clientStartTimeout = 30 * time.Second
+ clientStopTimeout = 10 * time.Second
+ defaultLogLevel = "warn"
+ defaultSSHDetectionTimeout = 20 * time.Second
)
func main() {
@@ -125,10 +127,15 @@ func createSSHMethod(client *netbird.Client) js.Func {
username = args[2].String()
}
+ var jwtToken string
+ if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
+ jwtToken = args[3].String()
+ }
+
return createPromise(func(resolve, reject js.Value) {
sshClient := ssh.NewClient(client)
- if err := sshClient.Connect(host, port, username); err != nil {
+ if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
reject.Invoke(err.Error())
return
}
@@ -191,12 +198,46 @@ func createPromise(handler func(resolve, reject js.Value)) js.Value {
}))
}
+// createDetectSSHServerMethod creates the SSH server detection method
+func createDetectSSHServerMethod(client *netbird.Client) js.Func {
+ return js.FuncOf(func(this js.Value, args []js.Value) any {
+ if len(args) < 2 {
+ return js.ValueOf("error: requires host and port")
+ }
+
+ host := args[0].String()
+ port := args[1].Int()
+
+ timeoutMs := int(defaultSSHDetectionTimeout.Milliseconds())
+ if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
+ timeoutMs = args[2].Int()
+ if timeoutMs <= 0 {
+ return js.ValueOf("error: timeout must be positive")
+ }
+ }
+
+ return createPromise(func(resolve, reject js.Value) {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
+ defer cancel()
+
+ serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port)
+ if err != nil {
+ reject.Invoke(err.Error())
+ return
+ }
+
+ resolve.Invoke(js.ValueOf(serverType.RequiresJWT()))
+ })
+ })
+}
+
// createClientObject wraps the NetBird client in a JavaScript object
func createClientObject(client *netbird.Client) js.Value {
obj := make(map[string]interface{})
obj["start"] = createStartMethod(client)
obj["stop"] = createStopMethod(client)
+ obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go
index 4a23a4bc8..1678c3996 100644
--- a/client/wasm/internal/rdp/cert_validation.go
+++ b/client/wasm/internal/rdp/cert_validation.go
@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
}
}
-func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
- return &tls.Config{
+func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
+ config := &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte
@@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl
return nil
},
}
+
+ // CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
+ if requiresCredSSP {
+ config.MinVersion = tls.VersionTLS12
+ config.MaxVersion = tls.VersionTLS12
+ } else {
+ config.MinVersion = tls.VersionTLS12
+ config.MaxVersion = tls.VersionTLS13
+ }
+
+ return config
}
diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go
index 8062a05cc..16bf63bb9 100644
--- a/client/wasm/internal/rdp/rdcleanpath.go
+++ b/client/wasm/internal/rdp/rdcleanpath.go
@@ -6,11 +6,13 @@ import (
"context"
"crypto/tls"
"encoding/asn1"
+ "errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
+ "time"
log "github.com/sirupsen/logrus"
)
@@ -19,18 +21,34 @@ const (
RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws"
+
+ rdpDialTimeout = 15 * time.Second
+
+ GeneralErrorCode = 1
+ WSAETimedOut = 10060
+ WSAEConnRefused = 10061
+ WSAEConnAborted = 10053
+ WSAEConnReset = 10054
+ WSAEGenericError = 10050
)
type RDCleanPathPDU struct {
- Version int64 `asn1:"tag:0,explicit"`
- Error []byte `asn1:"tag:1,explicit,optional"`
- Destination string `asn1:"utf8,tag:2,explicit,optional"`
- ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
- ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
- PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
- X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
- ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
- ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
+ Version int64 `asn1:"tag:0,explicit"`
+ Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
+ Destination string `asn1:"utf8,tag:2,explicit,optional"`
+ ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
+ ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
+ PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
+ X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
+ ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
+ ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
+}
+
+type RDCleanPathErr struct {
+ ErrorCode int16 `asn1:"tag:0,explicit"`
+ HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
+ WSALastError int16 `asn1:"tag:2,explicit,optional"`
+ TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
}
type RDCleanPathProxy struct {
@@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
- rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
+ ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
+ defer cancel()
+
+ rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
+ p.sendRDCleanPathError(conn, newWSAError(err))
return
}
conn.rdpConn = rdpConn
@@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
_, err = rdpConn.Write(firstPacket)
if err != nil {
log.Errorf("Failed to write first packet: %v", err)
+ p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
n, err := rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
+ p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}
+
+func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
+ data, err := asn1.Marshal(pdu)
+ if err != nil {
+ log.Errorf("Failed to marshal error PDU: %v", err)
+ return
+ }
+ p.sendToWebSocket(conn, data)
+}
+
+func errorToWSACode(err error) int16 {
+ if err == nil {
+ return WSAEGenericError
+ }
+ var netErr *net.OpError
+ if errors.As(err, &netErr) && netErr.Timeout() {
+ return WSAETimedOut
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ return WSAETimedOut
+ }
+ if errors.Is(err, context.Canceled) {
+ return WSAEConnAborted
+ }
+ if errors.Is(err, io.EOF) {
+ return WSAEConnReset
+ }
+ return WSAEGenericError
+}
+
+func newWSAError(err error) RDCleanPathPDU {
+ return RDCleanPathPDU{
+ Version: RDCleanPathVersion,
+ Error: RDCleanPathErr{
+ ErrorCode: GeneralErrorCode,
+ WSALastError: errorToWSACode(err),
+ },
+ }
+}
+
+func newHTTPError(statusCode int16) RDCleanPathPDU {
+ return RDCleanPathPDU{
+ Version: RDCleanPathVersion,
+ Error: RDCleanPathErr{
+ ErrorCode: GeneralErrorCode,
+ HTTPStatusCode: statusCode,
+ },
+ }
+}
diff --git a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go
index 010efa5ea..97bb46338 100644
--- a/client/wasm/internal/rdp/rdcleanpath_handlers.go
+++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go
@@ -3,6 +3,7 @@
package rdp
import (
+ "context"
"crypto/tls"
"encoding/asn1"
"io"
@@ -11,11 +12,17 @@ import (
log "github.com/sirupsen/logrus"
)
+const (
+ // MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
+ protocolSSL = 0x00000001
+ protocolHybridEx = 0x00000008
+)
+
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion {
- p.sendRDCleanPathError(conn, "Unsupported version")
+ p.sendRDCleanPathError(conn, newHTTPError(400))
return
}
@@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
destination = pdu.Destination
}
- rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
+ ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
+ defer cancel()
+
+ rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
- p.sendRDCleanPathError(conn, "Connection failed")
+ p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn)
return
}
@@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
p.setupTLSConnection(conn, pdu)
}
+// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
+// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
+// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
+func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
+ const minResponseLength = 19
+
+ if len(x224Response) < minResponseLength {
+ return false, 0, false
+ }
+
+ // Per X.224 specification:
+ // x224Response[0] == 0x03: Length of X.224 header (3 bytes)
+ // x224Response[5] == 0xD0: X.224 Data TPDU code
+ if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
+ return false, 0, false
+ }
+
+ if x224Response[11] == 0x02 {
+ flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
+ uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
+
+ hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
+ return hasNLA, flags, true
+ }
+
+ return false, 0, false
+}
+
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 {
@@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
- p.sendRDCleanPathError(conn, "Failed to forward X.224")
+ p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
- p.sendRDCleanPathError(conn, "Failed to read X.224 response")
+ p.sendRDCleanPathError(conn, newWSAError(err))
return
}
x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
}
- tlsConfig := p.getTLSConfigWithValidation(conn)
+ requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
+ if detected {
+ if requiresCredSSP {
+ log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
+ } else {
+ log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
+ }
+ } else {
+ log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
+ }
+
+ tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err)
- p.sendRDCleanPathError(conn, "TLS handshake failed")
+ p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
p.cleanupConnection(conn)
}
-func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
- if len(pdu.X224ConnectionPDU) > 0 {
- log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
- _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
- if err != nil {
- log.Errorf("Failed to write X.224 PDU: %v", err)
- p.sendRDCleanPathError(conn, "Failed to forward X.224")
- return
- }
-
- response := make([]byte, 1024)
- n, err := conn.rdpConn.Read(response)
- if err != nil {
- log.Errorf("Failed to read X.224 response: %v", err)
- p.sendRDCleanPathError(conn, "Failed to read X.224 response")
- return
- }
-
- responsePDU := RDCleanPathPDU{
- Version: RDCleanPathVersion,
- X224ConnectionPDU: response[:n],
- ServerAddr: conn.destination,
- }
-
- p.sendRDCleanPathPDU(conn, responsePDU)
- } else {
- responsePDU := RDCleanPathPDU{
- Version: RDCleanPathVersion,
- ServerAddr: conn.destination,
- }
- p.sendRDCleanPathPDU(conn, responsePDU)
- }
-
- go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
- go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
-
- <-conn.ctx.Done()
- log.Debug("TCP connection context done, cleaning up")
- p.cleanupConnection(conn)
-}
-
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
@@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
p.sendToWebSocket(conn, data)
}
-func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
- pdu := RDCleanPathPDU{
- Version: RDCleanPathVersion,
- Error: []byte(errorMsg),
- }
-
- data, err := asn1.Marshal(pdu)
- if err != nil {
- log.Errorf("Failed to marshal error PDU: %v", err)
- return
- }
-
- p.sendToWebSocket(conn, data)
-}
-
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
msgChan := make(chan []byte)
errChan := make(chan error)
diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go
index ca35525eb..568437e56 100644
--- a/client/wasm/internal/ssh/client.go
+++ b/client/wasm/internal/ssh/client.go
@@ -13,6 +13,7 @@ import (
"golang.org/x/crypto/ssh"
netbird "github.com/netbirdio/netbird/client/embed"
+ nbssh "github.com/netbirdio/netbird/client/ssh"
)
const (
@@ -45,34 +46,19 @@ func NewClient(nbClient *netbird.Client) *Client {
}
// Connect establishes an SSH connection through NetBird network
-func (c *Client) Connect(host string, port int, username string) error {
+func (c *Client) Connect(host string, port int, username, jwtToken string) error {
addr := fmt.Sprintf("%s:%d", host, port)
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
- var authMethods []ssh.AuthMethod
-
- nbConfig, err := c.nbClient.GetConfig()
+ authMethods, err := c.getAuthMethods(jwtToken)
if err != nil {
- return fmt.Errorf("get NetBird config: %w", err)
+ return err
}
- if nbConfig.SSHKey == "" {
- return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
- }
-
- signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
- if err != nil {
- return fmt.Errorf("parse NetBird SSH private key: %w", err)
- }
-
- pubKey := signer.PublicKey()
- logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
-
- authMethods = append(authMethods, ssh.PublicKeys(signer))
config := &ssh.ClientConfig{
User: username,
Auth: authMethods,
- HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ HostKeyCallback: nbssh.CreateHostKeyCallback(c.nbClient),
Timeout: sshDialTimeout,
}
@@ -96,6 +82,33 @@ func (c *Client) Connect(host string, port int, username string) error {
return nil
}
+// getAuthMethods returns SSH authentication methods, preferring JWT if available
+func (c *Client) getAuthMethods(jwtToken string) ([]ssh.AuthMethod, error) {
+ if jwtToken != "" {
+ logrus.Debugf("SSH: Using JWT password authentication")
+ return []ssh.AuthMethod{ssh.Password(jwtToken)}, nil
+ }
+
+ logrus.Debugf("SSH: No JWT token, using public key authentication")
+
+ nbConfig, err := c.nbClient.GetConfig()
+ if err != nil {
+ return nil, fmt.Errorf("get NetBird config: %w", err)
+ }
+
+ if nbConfig.SSHKey == "" {
+ return nil, fmt.Errorf("no NetBird SSH key available")
+ }
+
+ signer, err := ssh.ParsePrivateKey([]byte(nbConfig.SSHKey))
+ if err != nil {
+ return nil, fmt.Errorf("parse NetBird SSH private key: %w", err)
+ }
+
+ logrus.Debugf("SSH: Added public key auth")
+ return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil
+}
+
// StartSession starts an SSH session with PTY
func (c *Client) StartSession(cols, rows int) error {
if c.sshClient == nil {
diff --git a/client/wasm/internal/ssh/key.go b/client/wasm/internal/ssh/key.go
deleted file mode 100644
index 4868ba30a..000000000
--- a/client/wasm/internal/ssh/key.go
+++ /dev/null
@@ -1,50 +0,0 @@
-//go:build js
-
-package ssh
-
-import (
- "crypto/x509"
- "encoding/pem"
- "fmt"
- "strings"
-
- "github.com/sirupsen/logrus"
- "golang.org/x/crypto/ssh"
-)
-
-// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
-func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
- keyStr := string(keyPEM)
- if !strings.Contains(keyStr, "-----BEGIN") {
- keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
- }
-
- signer, err := ssh.ParsePrivateKey(keyPEM)
- if err == nil {
- return signer, nil
- }
- logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
-
- block, _ := pem.Decode(keyPEM)
- if block == nil {
- keyPreview := string(keyPEM)
- if len(keyPreview) > 100 {
- keyPreview = keyPreview[:100]
- }
- return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
- }
-
- key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
- if err != nil {
- logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
- if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
- return ssh.NewSignerFromKey(rsaKey)
- }
- if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
- return ssh.NewSignerFromKey(ecKey)
- }
- return nil, fmt.Errorf("parse private key: %w", err)
- }
-
- return ssh.NewSignerFromKey(key)
-}
diff --git a/dns/dns.go b/dns/dns.go
index f889a32ec..aa0e16eb1 100644
--- a/dns/dns.go
+++ b/dns/dns.go
@@ -19,6 +19,10 @@ const (
RootZone = "."
// DefaultClass is the class supported by the system
DefaultClass = "IN"
+ // ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort.
+ ForwarderClientPort uint16 = 5353
+ // ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here.
+ ForwarderServerPort uint16 = 22054
)
const invalidHostLabel = "[^a-zA-Z0-9-]+"
@@ -31,6 +35,8 @@ type Config struct {
NameServerGroups []*NameServerGroup
// CustomZones contains a list of custom zone
CustomZones []CustomZone
+ // ForwarderPort is the port clients should connect to on routing peers for DNS forwarding
+ ForwarderPort uint16
}
// CustomZone represents a custom zone to be resolved by the dns server
@@ -39,6 +45,10 @@ type CustomZone struct {
Domain string
// Records custom zone records
Records []SimpleRecord
+ // SearchDomainDisabled indicates whether to add match domains to a search domains list or not
+ SearchDomainDisabled bool
+ // SkipPTRProcess indicates whether a client should process PTR records from custom zones
+ SkipPTRProcess bool
}
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
diff --git a/formatter/hook/hook.go b/formatter/hook/hook.go
index c0d8c4eba..f0ee509f8 100644
--- a/formatter/hook/hook.go
+++ b/formatter/hook/hook.go
@@ -60,14 +60,7 @@ func (hook ContextHook) Fire(entry *logrus.Entry) error {
entry.Data["context"] = source
- switch source {
- case HTTPSource:
- addHTTPFields(entry)
- case GRPCSource:
- addGRPCFields(entry)
- case SystemSource:
- addSystemFields(entry)
- }
+ addFields(entry)
return nil
}
@@ -99,7 +92,7 @@ func (hook ContextHook) parseSrc(filePath string) string {
return fmt.Sprintf("%s/%s", pkg, file)
}
-func addHTTPFields(entry *logrus.Entry) {
+func addFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
@@ -109,30 +102,6 @@ func addHTTPFields(entry *logrus.Entry) {
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
-}
-
-func addGRPCFields(entry *logrus.Entry) {
- if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
- entry.Data[context.RequestIDKey] = ctxReqID
- }
- if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
- entry.Data[context.AccountIDKey] = ctxAccountID
- }
- if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
- entry.Data[context.PeerIDKey] = ctxDeviceID
- }
-}
-
-func addSystemFields(entry *logrus.Entry) {
- if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
- entry.Data[context.RequestIDKey] = ctxReqID
- }
- if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
- entry.Data[context.UserIDKey] = ctxInitiatorID
- }
- if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
- entry.Data[context.AccountIDKey] = ctxAccountID
- }
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
diff --git a/go.mod b/go.mod
index a1560b409..8f4ec530b 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/netbirdio/netbird
-go 1.23.0
+go 1.24.10
require (
cunicu.li/go-rosenpass v0.4.0
@@ -16,9 +16,9 @@ require (
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
- github.com/vishvananda/netlink v1.3.0
- golang.org/x/crypto v0.40.0
- golang.org/x/sys v0.34.0
+ github.com/vishvananda/netlink v1.3.1
+ golang.org/x/crypto v0.45.0
+ golang.org/x/sys v0.38.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -28,9 +28,10 @@ require (
)
require (
- fyne.io/fyne/v2 v2.5.3
- fyne.io/systray v1.11.0
+ fyne.io/fyne/v2 v2.7.0
+ fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
+ github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.36.3
github.com/aws/aws-sdk-go-v2/config v1.29.14
github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2
@@ -43,7 +44,7 @@ require (
github.com/eko/gocache/lib/v4 v4.2.0
github.com/eko/gocache/store/go_cache/v4 v4.2.2
github.com/eko/gocache/store/redis/v4 v4.2.2
- github.com/fsnotify/fsnotify v1.7.0
+ github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.0
@@ -56,13 +57,14 @@ require (
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0
+ github.com/jackc/pgx/v5 v5.5.5
github.com/libdns/route53 v1.5.0
github.com/libp2p/go-netroute v0.2.1
+ github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
- github.com/nadoo/ipset v0.5.0
- github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
+ github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
@@ -75,14 +77,15 @@ require (
github.com/pion/stun/v3 v3.0.0
github.com/pion/transport/v3 v3.0.7
github.com/pion/turn/v3 v3.0.1
+ github.com/pkg/sftp v1.13.9
github.com/prometheus/client_golang v1.22.0
- github.com/quic-go/quic-go v0.48.2
+ github.com/quic-go/quic-go v0.49.1
github.com/redis/go-redis/v9 v9.7.3
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
- github.com/stretchr/testify v1.10.0
+ github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go v0.31.0
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
@@ -98,15 +101,17 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.35.0
go.opentelemetry.io/otel/sdk/metric v1.35.0
+ go.uber.org/mock v0.5.0
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
- golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
- golang.org/x/mod v0.25.0
- golang.org/x/net v0.42.0
- golang.org/x/oauth2 v0.28.0
- golang.org/x/sync v0.16.0
- golang.org/x/term v0.33.0
+ golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
+ golang.org/x/mod v0.30.0
+ golang.org/x/net v0.47.0
+ golang.org/x/oauth2 v0.30.0
+ golang.org/x/sync v0.18.0
+ golang.org/x/term v0.37.0
+ golang.org/x/time v0.12.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
@@ -123,10 +128,11 @@ require (
dario.cat/mergo v1.0.0 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
- github.com/BurntSushi/toml v1.4.0 // indirect
+ github.com/BurntSushi/toml v1.5.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
+ github.com/awnumar/memcall v0.4.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
@@ -146,7 +152,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
- github.com/containerd/containerd v1.7.27 // indirect
+ github.com/containerd/containerd v1.7.29 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
@@ -157,11 +163,12 @@ require (
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
- github.com/fredbi/uri v1.1.0 // indirect
- github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect
- github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 // indirect
- github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
- github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
+ github.com/fredbi/uri v1.1.1 // indirect
+ github.com/fyne-io/gl-js v0.2.0 // indirect
+ github.com/fyne-io/glfw-js v0.3.0 // indirect
+ github.com/fyne-io/image v0.1.1 // indirect
+ github.com/fyne-io/oksvg v0.2.0 // indirect
+ github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
@@ -169,7 +176,7 @@ require (
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/go-text/render v0.2.0 // indirect
- github.com/go-text/typesetting v0.2.0 // indirect
+ github.com/go-text/typesetting v0.2.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.1.2 // indirect
@@ -177,22 +184,23 @@ require (
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
- github.com/gopherjs/gopherjs v1.17.2 // indirect
+ github.com/hack-pad/go-indexeddb v0.3.2 // indirect
+ github.com/hack-pad/safejs v0.1.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
- github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
- github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
+ github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
- github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
+ github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
+ github.com/kr/fs v0.1.0 // indirect
github.com/libdns/libdns v0.2.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/magiconair/properties v1.8.7 // indirect
@@ -208,7 +216,8 @@ require (
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
- github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
+ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
+ github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
@@ -224,29 +233,27 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
- github.com/rymdport/portal v0.3.0 // indirect
+ github.com/rymdport/portal v0.4.2 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect
- github.com/vishvananda/netns v0.0.4 // indirect
+ github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.3 // indirect
- github.com/yuin/goldmark v1.7.1 // indirect
+ github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
- go.uber.org/mock v0.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
- golang.org/x/image v0.18.0 // indirect
- golang.org/x/text v0.27.0 // indirect
- golang.org/x/time v0.5.0 // indirect
- golang.org/x/tools v0.34.0 // indirect
+ golang.org/x/image v0.33.0 // indirect
+ golang.org/x/text v0.31.0 // indirect
+ golang.org/x/tools v0.39.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
diff --git a/go.sum b/go.sum
index 13838b82d..f10e1e6da 100644
--- a/go.sum
+++ b/go.sum
@@ -1,67 +1,28 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
-cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU=
-cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY=
-cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc=
-cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0=
-cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To=
-cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4=
-cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M=
-cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc=
-cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk=
-cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs=
-cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc=
-cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY=
-cloud.google.com/go v0.72.0/go.mod h1:M+5Vjvlc2wnp6tjzE102Dw08nGShTscUx2nZMufOKPI=
-cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmWk=
-cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg=
-cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8=
-cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0=
cloud.google.com/go/auth v0.3.0 h1:PRyzEpGfx/Z9e8+lHsbkoUVXD0gnu4MNmm7Gp8TQNIs=
cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9ogv5w=
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
-cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o=
-cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE=
-cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc=
-cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg=
-cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
-cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
-cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
-cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
-cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk=
-cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
-cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw=
-cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA=
-cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU=
-cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw=
-cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos=
-cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk=
-cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
-cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
-dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
-fyne.io/fyne/v2 v2.5.3 h1:k6LjZx6EzRZhClsuzy6vucLZBstdH2USDGHSGWq8ly8=
-fyne.io/fyne/v2 v2.5.3/go.mod h1:0GOXKqyvNwk3DLmsFu9v0oYM0ZcD1ysGnlHCerKoAmo=
-fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg=
-fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
+fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ=
+fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE=
+fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 h1:eA5/u2XRd8OUkoMqEv3IBlFYSruNlXD8bRHDiqm0VNI=
+fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
-github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0=
-github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
-github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
+github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
+github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0=
@@ -70,10 +31,10 @@ github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJT
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
-github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
-github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
-github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
-github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
+github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
+github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
+github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
+github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs=
@@ -114,8 +75,6 @@ github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ=
github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
-github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
-github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -138,23 +97,18 @@ github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
-github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
-github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
-github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
-github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
+github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
+github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
-github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
-github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
-github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
@@ -182,37 +136,31 @@ github.com/eko/gocache/store/redis/v4 v4.2.2/go.mod h1:LaTxLKx9TG/YUEybQvPMij++D
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
-github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po=
-github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
-github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
-github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g=
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
-github.com/fredbi/uri v1.1.0 h1:OqLpTXtyRg9ABReqvDGdJPqZUxs8cyBDOMXBbskCaB8=
-github.com/fredbi/uri v1.1.0/go.mod h1:aYTUoAXBOq7BLfVJ8GnKmfcuURosB1xyHDIfWeC/iW4=
+github.com/fredbi/uri v1.1.1 h1:xZHJC08GZNIUhbP5ImTHnt5Ya0T8FI2VAwI/37kh2Ko=
+github.com/fredbi/uri v1.1.1/go.mod h1:4+DZQ5zBjEwQCDmXW5JdIjz0PUA+yJbvtBv+u+adr5o=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
-github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
-github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
-github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe h1:A/wiwvQ0CAjPkuJytaD+SsXkPU0asQ+guQEIg1BJGX4=
-github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe/go.mod h1:d4clgH0/GrRwWjRzJJQXxT/h1TyuNSfF/X64zb/3Ggg=
-github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 h1:/1YRWFv9bAWkoo3SuxpFfzpXH0D/bQnTjNXyF4ih7Os=
-github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0/go.mod h1:gsGA2dotD4v0SR6PmPCYvS9JuOeMwAtmfvDE7mbYXMY=
-github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 h1:hnLq+55b7Zh7/2IRzWCpiTcAvjv/P8ERF+N7+xXbZhk=
-github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2/go.mod h1:eO7W361vmlPOrykIg+Rsh1SZ3tQBaOsfzZhsIOb/Lm0=
-github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
+github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
+github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
+github.com/fyne-io/gl-js v0.2.0 h1:+EXMLVEa18EfkXBVKhifYB6OGs3HwKO3lUElA0LlAjs=
+github.com/fyne-io/gl-js v0.2.0/go.mod h1:ZcepK8vmOYLu96JoxbCKJy2ybr+g1pTnaBDdl7c3ajI=
+github.com/fyne-io/glfw-js v0.3.0 h1:d8k2+Y7l+zy2pc7wlGRyPfTgZoqDf3AI4G+2zOWhWUk=
+github.com/fyne-io/glfw-js v0.3.0/go.mod h1:Ri6te7rdZtBgBpxLW19uBpp3Dl6K9K/bRaYdJ22G8Jk=
+github.com/fyne-io/image v0.1.1 h1:WH0z4H7qfvNUw5l4p3bC1q70sa5+YWVt6HCj7y4VNyA=
+github.com/fyne-io/image v0.1.1/go.mod h1:xrfYBh6yspc+KjkgdZU/ifUC9sPA5Iv7WYUBzQKK7JM=
+github.com/fyne-io/oksvg v0.2.0 h1:mxcGU2dx6nwjJsSA9PCYZDuoAcsZ/OuJlvg/Q9Njfo8=
+github.com/fyne-io/oksvg v0.2.0/go.mod h1:dJ9oEkPiWhnTFNCmRgEze+YNprJF7YRbpjgpWS4kzoI=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do=
github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
-github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 h1:zDw5v7qm4yH7N8C8uWd+8Ii9rROdgWxQuGoJ9WDXxfk=
-github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw=
-github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
-github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
-github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
+github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw2H/zl9nrd3eCZfYV+NfQA=
+github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
@@ -237,11 +185,10 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEe
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc=
github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU=
-github.com/go-text/typesetting v0.2.0 h1:fbzsgbmk04KiWtE+c3ZD4W2nmCRzBqrqQOvYlwAOdho=
-github.com/go-text/typesetting v0.2.0/go.mod h1:2+owI/sxa73XA581LAzVuEBZ3WEEV2pXeDswCH/3i1I=
-github.com/go-text/typesetting-utils v0.0.0-20240317173224-1986cbe96c66 h1:GUrm65PQPlhFSKjLPGOZNPNxLCybjzjYBzjfoBGaDUY=
-github.com/go-text/typesetting-utils v0.0.0-20240317173224-1986cbe96c66/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o=
-github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
+github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg82V8=
+github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M=
+github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0=
+github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
@@ -249,27 +196,15 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
-github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
-github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
-github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
-github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
-github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
-github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
-github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
-github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
-github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
-github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
@@ -279,25 +214,18 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
-github.com/golang/protobuf v1.5.1/go.mod h1:DopwsBzvsk0Fs44TXzsVbJyPhcCPeIwnvohx4u74HPM=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
-github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
-github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
@@ -308,25 +236,10 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
-github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
-github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
-github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg=
github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM=
-github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
-github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
-github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
-github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
-github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
-github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y=
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg=
-github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -334,61 +247,34 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
-github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
-github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA=
github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
-github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
-github.com/gopherjs/gopherjs v0.0.0-20211219123610-ec9572f70e60/go.mod h1:cz9oNYuRUWGdHmLF2IodMLkAhcPtXeULvcBNagUrxTI=
-github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
-github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
-github.com/goxjs/gl v0.0.0-20210104184919-e3fafc6f8f2a/go.mod h1:dy/f2gjY09hwVfIyATps4G2ai7/hLwLkc5TrPqONuXY=
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 h1:Fkzd8ktnpOR9h47SXHe2AYPwelXLH2GjGsjlAloiWfo=
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4=
-github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo=
-github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg=
-github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q=
-github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8=
+github.com/hack-pad/go-indexeddb v0.3.2 h1:DTqeJJYc1usa45Q5r52t01KhvlSN02+Oq+tQbSBI91A=
+github.com/hack-pad/go-indexeddb v0.3.2/go.mod h1:QvfTevpDVlkfomY498LhstjwbPW6QC4VC/lxYb0Kom0=
+github.com/hack-pad/safejs v0.1.0 h1:qPS6vjreAqh2amUqj4WNG1zIw7qlRQJ9K10eDKMCnE8=
+github.com/hack-pad/safejs v0.1.0/go.mod h1:HdS+bKF1NrE72VoXZeWzxFOVQVUSqZJAG0xNCnb+Tio=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
-github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80=
-github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60=
-github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM=
-github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
-github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
-github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU=
-github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4=
-github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
-github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek=
github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
-github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90=
-github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
-github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
-github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
-github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64=
-github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
-github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
-github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
-github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
-github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
-github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -399,8 +285,8 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
-github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 h1:Po+wkNdMmN+Zj1tDsJQy7mJlPlwGNQd9JZoPjObagf8=
-github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49/go.mod h1:YiutDnxPRLk5DLUFj6Rw4pRBBURZY07GFr54NdV9mQg=
+github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade h1:FmusiCI1wHw+XQbvL9M+1r/C3SPqKrmBaIOYwVfQoDE=
+github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
@@ -410,12 +296,8 @@ github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHW
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
-github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
-github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
-github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
-github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e h1:LvL4XsI70QxOGHed6yhQtAU34Kx3Qq2wwBzGFKY8zKk=
-github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
-github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
+github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
+github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
@@ -425,12 +307,10 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
+github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
-github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
@@ -442,14 +322,13 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
+github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU=
+github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI=
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
-github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
-github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
-github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
@@ -461,21 +340,12 @@ github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
-github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
-github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
-github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
-github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI=
-github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
-github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY=
-github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
-github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
-github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
@@ -490,29 +360,26 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
-github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
-github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
-github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
-github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
-github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
-github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
-github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
+github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
+github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
-github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
-github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
+github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
+github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
+github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
+github.com/nicksnyder/go-i18n/v2 v2.5.1/go.mod h1:DrhgsSDZxoAfvVrBVLXoxZn/pN5TXqaDbq7ju94viiQ=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
@@ -534,10 +401,8 @@ github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQ
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
-github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
-github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8=
github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0=
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
@@ -569,15 +434,14 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
-github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA=
github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo=
-github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI=
+github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw=
+github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
@@ -590,51 +454,35 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
-github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
-github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
+github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0=
+github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
-github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
-github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM=
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
-github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
-github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
-github.com/rymdport/portal v0.3.0 h1:QRHcwKwx3kY5JTQcsVhmhC3TGqGQb9LFghVNUy8AdB8=
-github.com/rymdport/portal v0.3.0/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
-github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
+github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
+github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU=
github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
-github.com/shurcooL/go v0.0.0-20200502201357-93f07166e636/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk=
-github.com/shurcooL/httpfs v0.0.0-20190707220628-8d4bc4ba7749/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg=
-github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
-github.com/shurcooL/vfsgen v0.0.0-20200824052919-0d455de96546/go.mod h1:TrYk7fJVaAttu97ZZKrO9UbRa8izdowaMIZcxYMbVaw=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
-github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
-github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
-github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I=
-github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
-github.com/spf13/cobra v1.2.1/go.mod h1:ExllRjgxM/piMAM+3tAZvg8fsklGAf3tPfi+i8t68Nk=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
-github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
-github.com/spf13/viper v1.8.1/go.mod h1:o0Pch8wJ9BVSWGQMbra6iw0oQ5oktSIBaujf1rJH9Ns=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=
@@ -644,7 +492,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
-github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
@@ -656,9 +503,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
-github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/testcontainers/testcontainers-go v0.31.0 h1:W0VwIhcEVhRflwL9as3dhY6jXjVCA27AkmbnZ+UTh3U=
github.com/testcontainers/testcontainers-go v0.31.0/go.mod h1:D2lAoA0zUFiSY+eAflqK5mcUx/A5hrrORaEQrd0SefI=
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 h1:790+S8ewZYCbG+o8IiFlZ8ZZ33XbNO6zV9qhU6xhlRk=
@@ -681,25 +527,22 @@ github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYg
github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
-github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
-github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
-github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
-github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
+github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
+github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
+github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
+github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
-github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
-github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
-github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
-github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U=
-github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
+github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic=
+github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zcalusic/sysinfo v1.1.3 h1:u/AVENkuoikKuIZ4sUEJ6iibpmQP6YpGD8SSMCrqAF0=
@@ -710,16 +553,6 @@ github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg=
github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ=
github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
-go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs=
-go.etcd.io/etcd/client/pkg/v3 v3.5.0/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g=
-go.etcd.io/etcd/client/v2 v2.305.0/go.mod h1:h9puh54ZTgAKtEbut2oe9P4L/oqKCVB6xsXlzd7alYQ=
-go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
-go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
-go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
-go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
-go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
-go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
-go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
@@ -746,211 +579,113 @@ go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
-go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
-go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
-go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
-go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
+go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
+go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
-go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
-golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
+golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
-golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
-golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
+golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
+golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
+golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
+golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
+golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
-golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
-golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
-golang.org/x/exp v0.0.0-20190731235908-ec7cb31e5a56/go.mod h1:JhuoJpWY28nO4Vef9tZUw9qufEGTyX1+7lmHxV5q5G4=
-golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
-golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY=
-golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
-golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
-golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
-golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
-golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
-golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
-golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
-golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
-golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
+golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
+golang.org/x/image v0.33.0/go.mod h1:DD3OsTYT9chzuzTQt+zMcOlBHgfoKQb1gry8p76Y1sc=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
-golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs=
-golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
-golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
-golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
-golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
-golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
-golang.org/x/mobile v0.0.0-20211207041440-4e6c2922fdee/go.mod h1:pe2sM7Uk+2Su1y7u/6Z8KJ24D7lepUjFZbhFOrmDfuQ=
-golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg=
-golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc=
-golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
-golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
+golang.org/x/mobile v0.0.0-20251113184115-a159579294ab h1:Iqyc+2zr7aGyLuEadIm0KRJP0Wwt+fhlXLa51Fxf1+Q=
+golang.org/x/mobile v0.0.0-20251113184115-a159579294ab/go.mod h1:Eq3Nh/5pFSWug2ohiudJ1iyU59SO78QFuh4qTTN++I0=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
-golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
-golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
-golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
+golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
+golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
+golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
+golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
-golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
-golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
-golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
-golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
-golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
+golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
-golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
-golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
+golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
+golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
+golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
+golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
-golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/oauth2 v0.0.0-20201109201403-9fd604954f58/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
-golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc=
-golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
+golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
+golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
-golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
-golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
+golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
+golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -962,98 +697,60 @@ golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
-golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
+golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
+golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
-golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
-golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
-golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
+golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
+golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
+golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
+golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
+golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
-golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
-golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
-golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
+golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
+golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
+golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
+golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
+golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
-golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
-golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
-golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
-golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
-golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
-golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
-golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
-golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
-golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw=
-golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw=
-golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8=
-golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
-golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
-golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
-golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
-golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
-golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
-golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
-golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE=
-golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.1.8-0.20211022200916-316ba0b74098/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
-golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
-golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
+golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
+golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
+golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
+golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -1064,103 +761,24 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
-google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
-google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
-google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
-google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
-google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
-google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
-google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
-google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
-google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
-google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
-google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
-google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
-google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
-google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
-google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
-google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc=
-google.golang.org/api v0.35.0/go.mod h1:/XrVsuzM0rZmrsbjJutiuftIzeuTQcEeaYcSk/mQ1dg=
-google.golang.org/api v0.36.0/go.mod h1:+z5ficQTmoYpPn8LCUNVpK5I7hwkpjbcgqA7I34qYtE=
-google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjRCQ8=
-google.golang.org/api v0.41.0/go.mod h1:RkxM5lITDfTzmyKFPt+wGrCJbVfniCr2ool8kTBzRTU=
-google.golang.org/api v0.43.0/go.mod h1:nQsDGjRXMo4lvh5hP0TKqF244gqhGcr/YSIykhUk/94=
-google.golang.org/api v0.44.0/go.mod h1:EBOGZqzyhtvMDoxwS97ctnh0zUmYY6CxqXsc1AvkYD8=
google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk=
google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
-google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
-google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
-google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
-google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
-google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
-google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
-google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
-google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
-google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
-google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8=
-google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA=
-google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
-google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA=
-google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20210222152913-aa3ee6e6a81c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20210303154014-9728d6b83eeb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A=
-google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0=
google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM=
google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
-google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
-google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
-google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
-google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
-google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60=
-google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk=
-google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
-google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
-google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
-google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0=
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
-google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8=
-google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
-google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
-google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
-google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
@@ -1171,7 +789,6 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4=
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
@@ -1180,14 +797,11 @@ google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
-gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
-gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8=
@@ -1197,14 +811,12 @@ gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
@@ -1221,12 +833,4 @@ gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
-honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
-honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
-rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
-rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
-rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh
index e3fcbfdde..92252d0b3 100755
--- a/infrastructure_files/configure.sh
+++ b/infrastructure_files/configure.sh
@@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
echo ""
- export NETBIRD_SIGNAL_PROTOCOL="https"
unset NETBIRD_LETSENCRYPT_DOMAIN
unset NETBIRD_MGMT_API_CERT_FILE
unset NETBIRD_MGMT_API_CERT_KEY_FILE
fi
+if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
+ export NETBIRD_SIGNAL_PROTOCOL="https"
+fi
+
# Check if management identity provider is set
if [ -n "$NETBIRD_MGMT_IDP" ]; then
EXTRA_CONFIG={}
diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl
index b24e853b4..1c9c63f78 100644
--- a/infrastructure_files/docker-compose.yml.tmpl
+++ b/infrastructure_files/docker-compose.yml.tmpl
@@ -40,13 +40,22 @@ services:
signal:
<<: *default
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
+ depends_on:
+ - dashboard
volumes:
- $SIGNAL_VOLUMENAME:/var/lib/netbird
+ - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
ports:
- $NETBIRD_SIGNAL_PORT:80
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
+ command: [
+ "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
+ "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
+ "--log-file", "console",
+ "--port", "80"
+ ]
# Relay
relay:
diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik
index 196e26a66..0010974c5 100644
--- a/infrastructure_files/docker-compose.yml.tmpl.traefik
+++ b/infrastructure_files/docker-compose.yml.tmpl.traefik
@@ -49,6 +49,7 @@ services:
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
+ - traefik.http.routers.netbird-signal.service=netbird-signal
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh
index bc326cd7e..09c5225ad 100644
--- a/infrastructure_files/getting-started-with-zitadel.sh
+++ b/infrastructure_files/getting-started-with-zitadel.sh
@@ -682,17 +682,6 @@ renderManagementJson() {
"URI": "stun:$NETBIRD_DOMAIN:3478"
}
],
- "TURNConfig": {
- "Turns": [
- {
- "Proto": "udp",
- "URI": "turn:$NETBIRD_DOMAIN:3478",
- "Username": "$TURN_USER",
- "Password": "$TURN_PASSWORD"
- }
- ],
- "TimeBasedCredentials": false
- },
"Relay": {
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
"CredentialsTTL": "24h",
diff --git a/management/internals/controllers/network_map/controller/cache/dns_config_cache.go b/management/internals/controllers/network_map/controller/cache/dns_config_cache.go
new file mode 100644
index 000000000..8cc634ef4
--- /dev/null
+++ b/management/internals/controllers/network_map/controller/cache/dns_config_cache.go
@@ -0,0 +1,31 @@
+package cache
+
+import (
+ "sync"
+
+ "github.com/netbirdio/netbird/shared/management/proto"
+)
+
+// DNSConfigCache is a thread-safe cache for DNS configuration components
+type DNSConfigCache struct {
+ NameServerGroups sync.Map
+}
+
+// GetNameServerGroup retrieves a cached name server group
+func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
+ if c == nil {
+ return nil, false
+ }
+ if value, ok := c.NameServerGroups.Load(key); ok {
+ return value.(*proto.NameServerGroup), true
+ }
+ return nil, false
+}
+
+// SetNameServerGroup stores a name server group in the cache
+func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
+ if c == nil {
+ return
+ }
+ c.NameServerGroups.Store(key, value)
+}
diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go
new file mode 100644
index 000000000..df16e1922
--- /dev/null
+++ b/management/internals/controllers/network_map/controller/controller.go
@@ -0,0 +1,829 @@
+package controller
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "slices"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/exp/maps"
+ "golang.org/x/mod/semver"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
+ "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
+ "github.com/netbirdio/netbird/management/internals/server/config"
+ "github.com/netbirdio/netbird/management/internals/shared/grpc"
+ "github.com/netbirdio/netbird/management/server/account"
+ "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
+ "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/posture"
+ "github.com/netbirdio/netbird/management/server/settings"
+ "github.com/netbirdio/netbird/management/server/store"
+ "github.com/netbirdio/netbird/management/server/telemetry"
+ "github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/management/proto"
+ "github.com/netbirdio/netbird/shared/management/status"
+ "github.com/netbirdio/netbird/util"
+)
+
+type Controller struct {
+ repo Repository
+ metrics *metrics
+ // This should not be here, but we need to maintain it for the time being
+ accountManagerMetrics *telemetry.AccountManagerMetrics
+ peersUpdateManager network_map.PeersUpdateManager
+ settingsManager settings.Manager
+ EphemeralPeersManager ephemeral.Manager
+
+ accountUpdateLocks sync.Map
+ sendAccountUpdateLocks sync.Map
+ updateAccountPeersBufferInterval atomic.Int64
+ // dnsDomain is used for peer resolution. This is appended to the peer's name
+ dnsDomain string
+ config *config.Config
+
+ requestBuffer account.RequestBuffer
+
+ proxyController port_forwarding.Controller
+
+ integratedPeerValidator integrated_validator.IntegratedValidator
+
+ holder *types.Holder
+
+ expNewNetworkMap bool
+ expNewNetworkMapAIDs map[string]struct{}
+}
+
+type bufferUpdate struct {
+ mu sync.Mutex
+ next *time.Timer
+ update atomic.Bool
+}
+
+var _ network_map.Controller = (*Controller)(nil)
+
+func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
+ nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
+ if err != nil {
+ log.Fatal(fmt.Errorf("error creating metrics: %w", err))
+ }
+
+ newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
+ if err != nil {
+ log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
+ newNetworkMapBuilder = false
+ }
+
+ ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
+ expIDs := make(map[string]struct{}, len(ids))
+ for _, id := range ids {
+ expIDs[id] = struct{}{}
+ }
+
+ return &Controller{
+ repo: newRepository(store),
+ metrics: nMetrics,
+ accountManagerMetrics: metrics.AccountManagerMetrics(),
+ peersUpdateManager: peersUpdateManager,
+ requestBuffer: requestBuffer,
+ integratedPeerValidator: integratedPeerValidator,
+ settingsManager: settingsManager,
+ dnsDomain: dnsDomain,
+ config: config,
+
+ proxyController: proxyController,
+ EphemeralPeersManager: ephemeralPeersManager,
+
+ holder: types.NewHolder(),
+ expNewNetworkMap: newNetworkMapBuilder,
+ expNewNetworkMapAIDs: expIDs,
+ }
+}
+
+func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
+ peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err)
+ }
+
+ c.EphemeralPeersManager.OnPeerConnected(ctx, peer)
+
+ return c.peersUpdateManager.CreateChannel(ctx, peerID), nil
+}
+
+func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) {
+ c.peersUpdateManager.CloseChannel(ctx, peerID)
+ peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err)
+ return
+ }
+ c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
+}
+
+func (c *Controller) CountStreams() int {
+ return c.peersUpdateManager.CountStreams()
+}
+
+func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
+ log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
+ var (
+ account *types.Account
+ err error
+ )
+ if c.experimentalNetworkMap(accountID) {
+ account = c.getAccountFromHolderOrInit(accountID)
+ } else {
+ account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ if err != nil {
+ return fmt.Errorf("failed to get account: %v", err)
+ }
+ }
+
+ globalStart := time.Now()
+
+ hasPeersConnected := false
+ for _, peer := range account.Peers {
+ if c.peersUpdateManager.HasChannel(peer.ID) {
+ hasPeersConnected = true
+ break
+ }
+
+ }
+
+ if !hasPeersConnected {
+ return nil
+ }
+
+ approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
+ if err != nil {
+ return fmt.Errorf("failed to get validate peers: %v", err)
+ }
+
+ var wg sync.WaitGroup
+ semaphore := make(chan struct{}, 10)
+
+ dnsCache := &cache.DNSConfigCache{}
+ dnsDomain := c.GetDNSDomain(account.Settings)
+ customZone := account.GetPeersCustomZone(ctx, dnsDomain)
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+ groupIDToUserIDs := account.GetActiveGroupUsers()
+
+ if c.experimentalNetworkMap(accountID) {
+ c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
+ }
+
+ proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
+ return fmt.Errorf("failed to get proxy network maps: %v", err)
+ }
+
+ extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
+ if err != nil {
+ return fmt.Errorf("failed to get flow enabled status: %v", err)
+ }
+
+ dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
+
+ for _, peer := range account.Peers {
+ if !c.peersUpdateManager.HasChannel(peer.ID) {
+ log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
+ continue
+ }
+
+ wg.Add(1)
+ semaphore <- struct{}{}
+ go func(p *nbpeer.Peer) {
+ defer wg.Done()
+ defer func() { <-semaphore }()
+
+ start := time.Now()
+
+ postureChecks, err := c.getPeerPostureChecks(account, p.ID)
+ if err != nil {
+ log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
+ return
+ }
+
+ c.metrics.CountCalcPostureChecksDuration(time.Since(start))
+ start = time.Now()
+
+ var remotePeerNetworkMap *types.NetworkMap
+
+ if c.experimentalNetworkMap(accountID) {
+ remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
+ } else {
+ remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
+ }
+
+ c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
+
+ proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
+ if ok {
+ remotePeerNetworkMap.Merge(proxyNetworkMap)
+ }
+
+ peerGroups := account.GetPeerGroups(p.ID)
+ start = time.Now()
+ update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
+ c.metrics.CountToSyncResponseDuration(time.Since(start))
+
+ c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
+ }(peer)
+ }
+
+ wg.Wait()
+ if c.accountManagerMetrics != nil {
+ c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
+ }
+
+ return nil
+}
+
+func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error {
+ log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
+
+ bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
+ b := bufUpd.(*bufferUpdate)
+
+ if !b.mu.TryLock() {
+ b.update.Store(true)
+ return nil
+ }
+
+ if b.next != nil {
+ b.next.Stop()
+ }
+
+ go func() {
+ defer b.mu.Unlock()
+ _ = c.sendUpdateAccountPeers(ctx, accountID)
+ if !b.update.Load() {
+ return
+ }
+ b.update.Store(false)
+ if b.next == nil {
+ b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
+ _ = c.sendUpdateAccountPeers(ctx, accountID)
+ })
+ return
+ }
+ b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
+ }()
+
+ return nil
+}
+
+// UpdatePeers updates all peers that belong to an account.
+// Should be called when changes have to be synced to peers.
+func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
+ if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
+ return fmt.Errorf("recalculate network map cache: %v", err)
+ }
+
+ return c.sendUpdateAccountPeers(ctx, accountID)
+}
+
+func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
+ if !c.peersUpdateManager.HasChannel(peerId) {
+ return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
+ }
+
+ account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
+ if err != nil {
+ return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err)
+ }
+
+ peer := account.GetPeer(peerId)
+ if peer == nil {
+ return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId)
+ }
+
+ approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
+ if err != nil {
+ return fmt.Errorf("failed to get validated peers: %v", err)
+ }
+
+ dnsCache := &cache.DNSConfigCache{}
+ dnsDomain := c.GetDNSDomain(account.Settings)
+ customZone := account.GetPeersCustomZone(ctx, dnsDomain)
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+ groupIDToUserIDs := account.GetActiveGroupUsers()
+
+ postureChecks, err := c.getPeerPostureChecks(account, peerId)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
+ return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err)
+ }
+
+ proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
+ return err
+ }
+
+ var remotePeerNetworkMap *types.NetworkMap
+
+ if c.experimentalNetworkMap(accountId) {
+ remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
+ } else {
+ remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
+ }
+
+ proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
+ if ok {
+ remotePeerNetworkMap.Merge(proxyNetworkMap)
+ }
+
+ extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
+ if err != nil {
+ return fmt.Errorf("failed to get extra settings: %v", err)
+ }
+
+ peerGroups := account.GetPeerGroups(peerId)
+ dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
+
+ update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
+ c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
+
+ return nil
+}
+
+func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
+ log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
+
+ bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
+ b := bufUpd.(*bufferUpdate)
+
+ if !b.mu.TryLock() {
+ b.update.Store(true)
+ return nil
+ }
+
+ if b.next != nil {
+ b.next.Stop()
+ }
+
+ go func() {
+ defer b.mu.Unlock()
+ _ = c.UpdateAccountPeers(ctx, accountID)
+ if !b.update.Load() {
+ return
+ }
+ b.update.Store(false)
+ if b.next == nil {
+ b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
+ _ = c.UpdateAccountPeers(ctx, accountID)
+ })
+ return
+ }
+ b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
+ }()
+
+ return nil
+}
+
+func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
+ if isRequiresApproval {
+ network, err := c.repo.GetAccountNetwork(ctx, accountID)
+ if err != nil {
+ return nil, nil, nil, 0, err
+ }
+
+ emptyMap := &types.NetworkMap{
+ Network: network.Copy(),
+ }
+ return peer, emptyMap, nil, 0, nil
+ }
+
+ var (
+ account *types.Account
+ err error
+ )
+ if c.experimentalNetworkMap(accountID) {
+ account = c.getAccountFromHolderOrInit(accountID)
+ } else {
+ account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ if err != nil {
+ return nil, nil, nil, 0, err
+ }
+ }
+
+ approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
+ if err != nil {
+ return nil, nil, nil, 0, err
+ }
+
+ startPosture := time.Now()
+ postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
+ if err != nil {
+ return nil, nil, nil, 0, err
+ }
+ log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
+
+ customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
+
+ proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
+ return nil, nil, nil, 0, err
+ }
+
+ var networkMap *types.NetworkMap
+
+ if c.experimentalNetworkMap(accountID) {
+ networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
+ } else {
+ networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers())
+ }
+
+ proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
+ if ok {
+ networkMap.Merge(proxyNetworkMap)
+ }
+
+ dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
+
+ return peer, networkMap, postureChecks, dnsFwdPort, nil
+}
+
+func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
+ c.enrichAccountFromHolder(account)
+ account.InitNetworkMapBuilderIfNeeded(validatedPeers)
+}
+
+func (c *Controller) getPeerNetworkMapExp(
+ ctx context.Context,
+ accountId string,
+ peerId string,
+ validatedPeers map[string]struct{},
+ customZone nbdns.CustomZone,
+ metrics *telemetry.AccountManagerMetrics,
+) *types.NetworkMap {
+ account := c.getAccountFromHolderOrInit(accountId)
+ if account == nil {
+ log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
+ return &types.NetworkMap{
+ Network: &types.Network{},
+ }
+ }
+ return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
+}
+
+func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
+ c.enrichAccountFromHolder(account)
+ return account.OnPeerAddedUpdNetworkMapCache(peerId)
+}
+
+func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
+ c.enrichAccountFromHolder(account)
+ return account.OnPeerDeletedUpdNetworkMapCache(peerId)
+}
+
+func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
+ account := c.getAccountFromHolder(accountId)
+ if account == nil {
+ return
+ }
+ account.UpdatePeerInNetworkMapCache(peer)
+}
+
+func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
+ account.RecalculateNetworkMapCache(validatedPeers)
+ c.updateAccountInHolder(account)
+}
+
+func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
+ if c.experimentalNetworkMap(accountId) {
+ account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
+ if err != nil {
+ return err
+ }
+ validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
+ return err
+ }
+ c.recalculateNetworkMapCache(account, validatedPeers)
+ }
+ return nil
+}
+
+func (c *Controller) experimentalNetworkMap(accountId string) bool {
+ _, ok := c.expNewNetworkMapAIDs[accountId]
+ return c.expNewNetworkMap || ok
+}
+
+func (c *Controller) enrichAccountFromHolder(account *types.Account) {
+ a := c.holder.GetAccount(account.Id)
+ if a == nil {
+ c.holder.AddAccount(account)
+ return
+ }
+ account.NetworkMapCache = a.NetworkMapCache
+ if account.NetworkMapCache == nil {
+ return
+ }
+ account.NetworkMapCache.UpdateAccountPointer(account)
+ c.holder.AddAccount(account)
+}
+
+func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
+ return c.holder.GetAccount(accountID)
+}
+
+func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account {
+ a := c.holder.GetAccount(accountID)
+ if a != nil {
+ return a
+ }
+ account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure)
+ if err != nil {
+ return nil
+ }
+ return account
+}
+
+func (c *Controller) updateAccountInHolder(account *types.Account) {
+ c.holder.AddAccount(account)
+}
+
+// GetDNSDomain returns the configured dnsDomain
+func (c *Controller) GetDNSDomain(settings *types.Settings) string {
+ if settings == nil {
+ return c.dnsDomain
+ }
+ if settings.DNSDomain == "" {
+ return c.dnsDomain
+ }
+
+ return settings.DNSDomain
+}
+
+// getPeerPostureChecks returns the posture checks applied for a given peer.
+func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
+ peerPostureChecks := make(map[string]*posture.Checks)
+
+ if len(account.PostureChecks) == 0 {
+ return nil, nil
+ }
+
+ for _, policy := range account.Policies {
+ if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
+ continue
+ }
+
+ if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
+ return nil, err
+ }
+ }
+
+ return maps.Values(peerPostureChecks), nil
+}
+
+func (c *Controller) StartWarmup(ctx context.Context) {
+ var initialInterval int64
+ intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
+ interval, err := strconv.Atoi(intervalStr)
+ if err != nil {
+ initialInterval = 1
+ log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
+ } else {
+ initialInterval = int64(interval) * 10
+ go func() {
+ startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
+ startupPeriod, err := strconv.Atoi(startupPeriodStr)
+ if err != nil {
+ startupPeriod = 1
+ log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
+ }
+ time.Sleep(time.Duration(startupPeriod) * time.Second)
+ c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
+ log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
+ }()
+ }
+ c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond))
+ log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
+
+}
+
+// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
+// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
+func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
+ if len(peers) == 0 {
+ return int64(network_map.OldForwarderPort)
+ }
+
+ reqVer := semver.Canonical(requiredVersion)
+
+ // Check if all peers have the required version or newer
+ for _, peer := range peers {
+
+ // Development version is always supported
+ if peer.Meta.WtVersion == "development" {
+ continue
+ }
+ peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
+ if peerVersion == "" {
+ // If any peer doesn't have version info, return 0
+ return int64(network_map.OldForwarderPort)
+ }
+
+ // Compare versions
+ if semver.Compare(peerVersion, reqVer) < 0 {
+ return int64(network_map.OldForwarderPort)
+ }
+ }
+
+ // All peers have the required version or newer
+ return int64(network_map.DnsForwarderPort)
+}
+
+// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
+func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
+ isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
+ if err != nil {
+ return err
+ }
+
+ if !isInGroup {
+ return nil
+ }
+
+ for _, sourcePostureCheckID := range policy.SourcePostureChecks {
+ postureCheck := account.GetPostureChecks(sourcePostureCheckID)
+ if postureCheck == nil {
+ return errors.New("failed to add policy posture checks: posture checks not found")
+ }
+ peerPostureChecks[sourcePostureCheckID] = postureCheck
+ }
+
+ return nil
+}
+
+// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
+func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
+ for _, rule := range policy.Rules {
+ if !rule.Enabled {
+ continue
+ }
+
+ for _, sourceGroup := range rule.Sources {
+ group := account.GetGroup(sourceGroup)
+ if group == nil {
+ return false, fmt.Errorf("failed to check peer in policy source group: group not found")
+ }
+
+ if slices.Contains(group.Peers, peerID) {
+ return true, nil
+ }
+ }
+ }
+
+ return false, nil
+}
+
+func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
+ peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
+ if err != nil {
+ return fmt.Errorf("failed to get peers by ids: %w", err)
+ }
+
+ for _, peer := range peers {
+ c.UpdatePeerInNetworkMapCache(accountID, peer)
+ }
+
+ err = c.bufferSendUpdateAccountPeers(ctx, accountID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
+ }
+
+ return nil
+}
+
+func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
+ for _, peerID := range peerIDs {
+ if c.experimentalNetworkMap(accountID) {
+ account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ if err != nil {
+ return err
+ }
+
+ err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return c.bufferSendUpdateAccountPeers(ctx, accountID)
+}
+
+func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
+ network, err := c.repo.GetAccountNetwork(ctx, accountID)
+ if err != nil {
+ return err
+ }
+
+ peers, err := c.repo.GetAccountPeers(ctx, accountID)
+ if err != nil {
+ return err
+ }
+
+ dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
+ for _, peerID := range peerIDs {
+ c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
+ Update: &proto.SyncResponse{
+ RemotePeers: []*proto.RemotePeerConfig{},
+ RemotePeersIsEmpty: true,
+ NetworkMap: &proto.NetworkMap{
+ Serial: network.CurrentSerial(),
+ RemotePeers: []*proto.RemotePeerConfig{},
+ RemotePeersIsEmpty: true,
+ FirewallRules: []*proto.FirewallRule{},
+ FirewallRulesIsEmpty: true,
+ DNSConfig: &proto.DNSConfig{
+ ForwarderPort: dnsFwdPort,
+ },
+ },
+ },
+ })
+ c.peersUpdateManager.CloseChannel(ctx, peerID)
+
+ if c.experimentalNetworkMap(accountID) {
+ account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
+ continue
+ }
+ err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
+ continue
+ }
+ }
+ }
+
+ return c.bufferSendUpdateAccountPeers(ctx, accountID)
+}
+
+// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
+func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
+ account, err := c.repo.GetAccountByPeerID(ctx, peerID)
+ if err != nil {
+ return nil, err
+ }
+
+ peer := account.GetPeer(peerID)
+ if peer == nil {
+ return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
+ }
+
+ groups := make(map[string][]string)
+ for groupID, group := range account.Groups {
+ groups[groupID] = group.Peers
+ }
+
+ validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
+ if err != nil {
+ return nil, err
+ }
+ customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
+
+ proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
+ return nil, err
+ }
+
+ var networkMap *types.NetworkMap
+
+ if c.experimentalNetworkMap(peer.AccountID) {
+ networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
+ } else {
+ networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
+ }
+
+ proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
+ if ok {
+ networkMap.Merge(proxyNetworkMap)
+ }
+
+ return networkMap, nil
+}
+
+func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
+ c.peersUpdateManager.CloseChannels(ctx, peerIDs)
+}
diff --git a/management/internals/controllers/network_map/controller/controller_test.go b/management/internals/controllers/network_map/controller/controller_test.go
new file mode 100644
index 000000000..90e7b6e18
--- /dev/null
+++ b/management/internals/controllers/network_map/controller/controller_test.go
@@ -0,0 +1,109 @@
+package controller
+
+import (
+ "testing"
+
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+)
+
+func TestComputeForwarderPort(t *testing.T) {
+ // Test with empty peers list
+ peers := []*nbpeer.Peer{}
+ result := computeForwarderPort(peers, "v0.59.0")
+ if result != int64(network_map.OldForwarderPort) {
+ t.Errorf("Expected %d for empty peers list, got %d", network_map.OldForwarderPort, result)
+ }
+
+ // Test with peers that have old versions
+ peers = []*nbpeer.Peer{
+ {
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.57.0",
+ },
+ },
+ {
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.26.0",
+ },
+ },
+ }
+ result = computeForwarderPort(peers, "v0.59.0")
+ if result != int64(network_map.OldForwarderPort) {
+ t.Errorf("Expected %d for peers with old versions, got %d", network_map.OldForwarderPort, result)
+ }
+
+ // Test with peers that have new versions
+ peers = []*nbpeer.Peer{
+ {
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.59.0",
+ },
+ },
+ {
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.59.0",
+ },
+ },
+ }
+ result = computeForwarderPort(peers, "v0.59.0")
+ if result != int64(network_map.DnsForwarderPort) {
+ t.Errorf("Expected %d for peers with new versions, got %d", network_map.DnsForwarderPort, result)
+ }
+
+ // Test with peers that have mixed versions
+ peers = []*nbpeer.Peer{
+ {
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.59.0",
+ },
+ },
+ {
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.57.0",
+ },
+ },
+ }
+ result = computeForwarderPort(peers, "v0.59.0")
+ if result != int64(network_map.OldForwarderPort) {
+ t.Errorf("Expected %d for peers with mixed versions, got %d", network_map.OldForwarderPort, result)
+ }
+
+ // Test with peers that have empty version
+ peers = []*nbpeer.Peer{
+ {
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "",
+ },
+ },
+ }
+ result = computeForwarderPort(peers, "v0.59.0")
+ if result != int64(network_map.OldForwarderPort) {
+ t.Errorf("Expected %d for peers with empty version, got %d", network_map.OldForwarderPort, result)
+ }
+
+ peers = []*nbpeer.Peer{
+ {
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "development",
+ },
+ },
+ }
+ result = computeForwarderPort(peers, "v0.59.0")
+ if result == int64(network_map.OldForwarderPort) {
+ t.Errorf("Expected %d for peers with dev version, got %d", network_map.DnsForwarderPort, result)
+ }
+
+ // Test with peers that have unknown version string
+ peers = []*nbpeer.Peer{
+ {
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "unknown",
+ },
+ },
+ }
+ result = computeForwarderPort(peers, "v0.59.0")
+ if result != int64(network_map.OldForwarderPort) {
+ t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result)
+ }
+}
diff --git a/management/internals/controllers/network_map/controller/metrics.go b/management/internals/controllers/network_map/controller/metrics.go
new file mode 100644
index 000000000..5832d2130
--- /dev/null
+++ b/management/internals/controllers/network_map/controller/metrics.go
@@ -0,0 +1,15 @@
+package controller
+
+import (
+ "github.com/netbirdio/netbird/management/server/telemetry"
+)
+
+type metrics struct {
+ *telemetry.UpdateChannelMetrics
+}
+
+func newMetrics(updateChannelMetrics *telemetry.UpdateChannelMetrics) (*metrics, error) {
+ return &metrics{
+ updateChannelMetrics,
+ }, nil
+}
diff --git a/management/internals/controllers/network_map/controller/repository.go b/management/internals/controllers/network_map/controller/repository.go
new file mode 100644
index 000000000..3ed51a5c3
--- /dev/null
+++ b/management/internals/controllers/network_map/controller/repository.go
@@ -0,0 +1,49 @@
+package controller
+
+import (
+ "context"
+
+ "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/store"
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+type Repository interface {
+ GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
+ GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
+ GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
+ GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
+ GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
+}
+
+type repository struct {
+ store store.Store
+}
+
+var _ Repository = (*repository)(nil)
+
+func newRepository(s store.Store) Repository {
+ return &repository{
+ store: s,
+ }
+}
+
+func (r *repository) GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) {
+ return r.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
+}
+
+func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) {
+ return r.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
+}
+
+func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
+ return r.store.GetAccountByPeerID(ctx, peerID)
+}
+
+func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
+ return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs)
+}
+
+func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
+ return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
+}
diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go
new file mode 100644
index 000000000..b1de7d017
--- /dev/null
+++ b/management/internals/controllers/network_map/interface.go
@@ -0,0 +1,39 @@
+package network_map
+
+//go:generate go run go.uber.org/mock/mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
+
+import (
+ "context"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/posture"
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+const (
+ EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
+ EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
+
+ DnsForwarderPort = nbdns.ForwarderServerPort
+ OldForwarderPort = nbdns.ForwarderClientPort
+ DnsForwarderPortMinVersion = "v0.59.0"
+)
+
+type Controller interface {
+ UpdateAccountPeers(ctx context.Context, accountID string) error
+ UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
+ BufferUpdateAccountPeers(ctx context.Context, accountID string) error
+ GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
+ GetDNSDomain(settings *types.Settings) string
+ StartWarmup(context.Context)
+ GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
+ CountStreams() int
+
+ OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
+ OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
+ OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
+ DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
+ OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
+ OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
+}
diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go
new file mode 100644
index 000000000..5a98eefa8
--- /dev/null
+++ b/management/internals/controllers/network_map/interface_mock.go
@@ -0,0 +1,240 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: ./interface.go
+//
+// Generated by this command:
+//
+// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
+//
+
+// Package network_map is a generated GoMock package.
+package network_map
+
+import (
+ context "context"
+ reflect "reflect"
+
+ peer "github.com/netbirdio/netbird/management/server/peer"
+ posture "github.com/netbirdio/netbird/management/server/posture"
+ types "github.com/netbirdio/netbird/management/server/types"
+ gomock "go.uber.org/mock/gomock"
+)
+
+// MockController is a mock of Controller interface.
+type MockController struct {
+ ctrl *gomock.Controller
+ recorder *MockControllerMockRecorder
+ isgomock struct{}
+}
+
+// MockControllerMockRecorder is the mock recorder for MockController.
+type MockControllerMockRecorder struct {
+ mock *MockController
+}
+
+// NewMockController creates a new mock instance.
+func NewMockController(ctrl *gomock.Controller) *MockController {
+ mock := &MockController{ctrl: ctrl}
+ mock.recorder = &MockControllerMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockController) EXPECT() *MockControllerMockRecorder {
+ return m.recorder
+}
+
+// BufferUpdateAccountPeers mocks base method.
+func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
+func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
+}
+
+// CountStreams mocks base method.
+func (m *MockController) CountStreams() int {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "CountStreams")
+ ret0, _ := ret[0].(int)
+ return ret0
+}
+
+// CountStreams indicates an expected call of CountStreams.
+func (mr *MockControllerMockRecorder) CountStreams() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountStreams", reflect.TypeOf((*MockController)(nil).CountStreams))
+}
+
+// DisconnectPeers mocks base method.
+func (m *MockController) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "DisconnectPeers", ctx, accountId, peerIDs)
+}
+
+// DisconnectPeers indicates an expected call of DisconnectPeers.
+func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, accountId, peerIDs any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, accountId, peerIDs)
+}
+
+// GetDNSDomain mocks base method.
+func (m *MockController) GetDNSDomain(settings *types.Settings) string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetDNSDomain", settings)
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// GetDNSDomain indicates an expected call of GetDNSDomain.
+func (mr *MockControllerMockRecorder) GetDNSDomain(settings any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSDomain", reflect.TypeOf((*MockController)(nil).GetDNSDomain), settings)
+}
+
+// GetNetworkMap mocks base method.
+func (m *MockController) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetNetworkMap", ctx, peerID)
+ ret0, _ := ret[0].(*types.NetworkMap)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetNetworkMap indicates an expected call of GetNetworkMap.
+func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkMap", reflect.TypeOf((*MockController)(nil).GetNetworkMap), ctx, peerID)
+}
+
+// GetValidatedPeerWithMap mocks base method.
+func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
+ ret0, _ := ret[0].(*peer.Peer)
+ ret1, _ := ret[1].(*types.NetworkMap)
+ ret2, _ := ret[2].([]*posture.Checks)
+ ret3, _ := ret[3].(int64)
+ ret4, _ := ret[4].(error)
+ return ret0, ret1, ret2, ret3, ret4
+}
+
+// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
+func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
+}
+
+// OnPeerConnected mocks base method.
+func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "OnPeerConnected", ctx, accountID, peerID)
+ ret0, _ := ret[0].(chan *UpdateMessage)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// OnPeerConnected indicates an expected call of OnPeerConnected.
+func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID)
+}
+
+// OnPeerDisconnected mocks base method.
+func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerID)
+}
+
+// OnPeerDisconnected indicates an expected call of OnPeerDisconnected.
+func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockController)(nil).OnPeerDisconnected), ctx, accountID, peerID)
+}
+
+// OnPeersAdded mocks base method.
+func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// OnPeersAdded indicates an expected call of OnPeersAdded.
+func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
+}
+
+// OnPeersDeleted mocks base method.
+func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// OnPeersDeleted indicates an expected call of OnPeersDeleted.
+func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
+}
+
+// OnPeersUpdated mocks base method.
+func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// OnPeersUpdated indicates an expected call of OnPeersUpdated.
+func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
+}
+
+// StartWarmup mocks base method.
+func (m *MockController) StartWarmup(arg0 context.Context) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "StartWarmup", arg0)
+}
+
+// StartWarmup indicates an expected call of StartWarmup.
+func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0)
+}
+
+// UpdateAccountPeer mocks base method.
+func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "UpdateAccountPeer", ctx, accountId, peerId)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// UpdateAccountPeer indicates an expected call of UpdateAccountPeer.
+func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeer", reflect.TypeOf((*MockController)(nil).UpdateAccountPeer), ctx, accountId, peerId)
+}
+
+// UpdateAccountPeers mocks base method.
+func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
+func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID)
+}
diff --git a/management/internals/controllers/network_map/network_map.go b/management/internals/controllers/network_map/network_map.go
new file mode 100644
index 000000000..e915c2193
--- /dev/null
+++ b/management/internals/controllers/network_map/network_map.go
@@ -0,0 +1 @@
+package network_map
diff --git a/management/internals/controllers/network_map/update_channel.go b/management/internals/controllers/network_map/update_channel.go
new file mode 100644
index 000000000..0b085b85f
--- /dev/null
+++ b/management/internals/controllers/network_map/update_channel.go
@@ -0,0 +1,13 @@
+package network_map
+
+import "context"
+
+type PeersUpdateManager interface {
+ SendUpdate(ctx context.Context, peerID string, update *UpdateMessage)
+ CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage
+ CloseChannel(ctx context.Context, peerID string)
+ CountStreams() int
+ HasChannel(peerID string) bool
+ CloseChannels(ctx context.Context, peerIDs []string)
+ GetAllConnectedPeers() map[string]struct{}
+}
diff --git a/management/server/updatechannel.go b/management/internals/controllers/network_map/update_channel/updatechannel.go
similarity index 87%
rename from management/server/updatechannel.go
rename to management/internals/controllers/network_map/update_channel/updatechannel.go
index da12f1b70..5f7db5300 100644
--- a/management/server/updatechannel.go
+++ b/management/internals/controllers/network_map/update_channel/updatechannel.go
@@ -1,4 +1,4 @@
-package server
+package update_channel
import (
"context"
@@ -7,38 +7,34 @@ import (
log "github.com/sirupsen/logrus"
- "github.com/netbirdio/netbird/shared/management/proto"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/telemetry"
- "github.com/netbirdio/netbird/management/server/types"
)
const channelBufferSize = 100
-type UpdateMessage struct {
- Update *proto.SyncResponse
- NetworkMap *types.NetworkMap
-}
-
type PeersUpdateManager struct {
// peerChannels is an update channel indexed by Peer.ID
- peerChannels map[string]chan *UpdateMessage
+ peerChannels map[string]chan *network_map.UpdateMessage
// channelsMux keeps the mutex to access peerChannels
channelsMux *sync.RWMutex
// metrics provides method to collect application metrics
metrics telemetry.AppMetrics
}
+var _ network_map.PeersUpdateManager = (*PeersUpdateManager)(nil)
+
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
return &PeersUpdateManager{
- peerChannels: make(map[string]chan *UpdateMessage),
+ peerChannels: make(map[string]chan *network_map.UpdateMessage),
channelsMux: &sync.RWMutex{},
metrics: metrics,
}
}
// SendUpdate sends update message to the peer's channel
-func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
+func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *network_map.UpdateMessage) {
start := time.Now()
var found, dropped bool
@@ -66,7 +62,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
}
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
-func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
+func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *network_map.UpdateMessage {
start := time.Now()
closed := false
@@ -85,7 +81,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
close(channel)
}
// mbragin: todo shouldn't it be more? or configurable?
- channel := make(chan *UpdateMessage, channelBufferSize)
+ channel := make(chan *network_map.UpdateMessage, channelBufferSize)
p.peerChannels[peerID] = channel
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
@@ -176,3 +172,9 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
return ok
}
+
+func (p *PeersUpdateManager) CountStreams() int {
+ p.channelsMux.RLock()
+ defer p.channelsMux.RUnlock()
+ return len(p.peerChannels)
+}
diff --git a/management/server/updatechannel_test.go b/management/internals/controllers/network_map/update_channel/updatechannel_test.go
similarity index 89%
rename from management/server/updatechannel_test.go
rename to management/internals/controllers/network_map/update_channel/updatechannel_test.go
index 0dc86563d..afc1e2c32 100644
--- a/management/server/updatechannel_test.go
+++ b/management/internals/controllers/network_map/update_channel/updatechannel_test.go
@@ -1,10 +1,11 @@
-package server
+package update_channel
import (
"context"
"testing"
"time"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -24,7 +25,7 @@ func TestCreateChannel(t *testing.T) {
func TestSendUpdate(t *testing.T) {
peer := "test-sendupdate"
peersUpdater := NewPeersUpdateManager(nil)
- update1 := &UpdateMessage{Update: &proto.SyncResponse{
+ update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
},
@@ -44,7 +45,7 @@ func TestSendUpdate(t *testing.T) {
peersUpdater.SendUpdate(context.Background(), peer, update1)
}
- update2 := &UpdateMessage{Update: &proto.SyncResponse{
+ update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
},
diff --git a/management/internals/controllers/network_map/update_message.go b/management/internals/controllers/network_map/update_message.go
new file mode 100644
index 000000000..33643bcbd
--- /dev/null
+++ b/management/internals/controllers/network_map/update_message.go
@@ -0,0 +1,9 @@
+package network_map
+
+import (
+ "github.com/netbirdio/netbird/shared/management/proto"
+)
+
+type UpdateMessage struct {
+ Update *proto.SyncResponse
+}
diff --git a/management/server/peers/ephemeral/interface.go b/management/internals/modules/peers/ephemeral/interface.go
similarity index 83%
rename from management/server/peers/ephemeral/interface.go
rename to management/internals/modules/peers/ephemeral/interface.go
index a1605b3b9..8fe25435c 100644
--- a/management/server/peers/ephemeral/interface.go
+++ b/management/internals/modules/peers/ephemeral/interface.go
@@ -2,10 +2,15 @@ package ephemeral
import (
"context"
+ "time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
+const (
+ EphemeralLifeTime = 10 * time.Minute
+)
+
type Manager interface {
LoadInitialPeers(ctx context.Context)
Stop()
diff --git a/management/server/peers/ephemeral/manager/ephemeral.go b/management/internals/modules/peers/ephemeral/manager/ephemeral.go
similarity index 85%
rename from management/server/peers/ephemeral/manager/ephemeral.go
rename to management/internals/modules/peers/ephemeral/manager/ephemeral.go
index 062ba69d2..15119045b 100644
--- a/management/server/peers/ephemeral/manager/ephemeral.go
+++ b/management/internals/modules/peers/ephemeral/manager/ephemeral.go
@@ -7,14 +7,15 @@ import (
log "github.com/sirupsen/logrus"
- nbAccount "github.com/netbirdio/netbird/management/server/account"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
+
"github.com/netbirdio/netbird/management/server/store"
)
const (
- ephemeralLifeTime = 10 * time.Minute
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
cleanupWindow = 1 * time.Minute
)
@@ -33,11 +34,11 @@ type ephemeralPeer struct {
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
// in worst case we will get invalid error message in this manager.
-// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted
+// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted
// automatically. Inactivity means the peer disconnected from the Management server.
type EphemeralManager struct {
- store store.Store
- accountManager nbAccount.Manager
+ store store.Store
+ peersManager peers.Manager
headPeer *ephemeralPeer
tailPeer *ephemeralPeer
@@ -49,12 +50,12 @@ type EphemeralManager struct {
}
// NewEphemeralManager instantiate new EphemeralManager
-func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *EphemeralManager {
+func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager {
return &EphemeralManager{
- store: store,
- accountManager: accountManager,
+ store: store,
+ peersManager: peersManager,
- lifeTime: ephemeralLifeTime,
+ lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
}
}
@@ -106,7 +107,7 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
}
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
-// is inactive it will be deleted after the ephemeralLifeTime period.
+// is inactive it will be deleted after the EphemeralLifeTime period.
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
@@ -180,20 +181,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
e.peersLock.Unlock()
- bufferAccountCall := make(map[string]struct{})
-
+ peerIDsPerAccount := make(map[string][]string)
for id, p := range deletePeers {
- log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
- err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
+ peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
+ }
+
+ for accountID, peerIDs := range peerIDsPerAccount {
+ log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID)
+ err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
- } else {
- bufferAccountCall[p.accountID] = struct{}{}
}
}
- for accountID := range bufferAccountCall {
- e.accountManager.BufferUpdateAccountPeers(ctx, accountID)
- }
}
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
diff --git a/management/server/peers/ephemeral/manager/ephemeral_test.go b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go
similarity index 69%
rename from management/server/peers/ephemeral/manager/ephemeral_test.go
rename to management/internals/modules/peers/ephemeral/manager/ephemeral_test.go
index fc7525c29..9d3ed246a 100644
--- a/management/server/peers/ephemeral/manager/ephemeral_test.go
+++ b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go
@@ -7,10 +7,13 @@ import (
"testing"
"time"
+ "github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
nbAccount "github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
@@ -91,17 +94,27 @@ func TestNewManager(t *testing.T) {
}
store := &MockStore{}
- am := MockAccountManager{
- store: store,
- }
+ ctrl := gomock.NewController(t)
+ peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
- mgr := NewEphemeralManager(store, &am)
+ // Expect DeletePeers to be called for ephemeral peers
+ peersManager.EXPECT().
+ DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
+ DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
+ for _, peerID := range peerIDs {
+ delete(store.account.Peers, peerID)
+ }
+ return nil
+ }).
+ AnyTimes()
+
+ mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
- startTime = startTime.Add(ephemeralLifeTime + 1)
+ startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers {
@@ -119,19 +132,29 @@ func TestNewManagerPeerConnected(t *testing.T) {
}
store := &MockStore{}
- am := MockAccountManager{
- store: store,
- }
+ ctrl := gomock.NewController(t)
+ peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
- mgr := NewEphemeralManager(store, &am)
+ // Expect DeletePeers to be called for ephemeral peers (except the connected one)
+ peersManager.EXPECT().
+ DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
+ DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
+ for _, peerID := range peerIDs {
+ delete(store.account.Peers, peerID)
+ }
+ return nil
+ }).
+ AnyTimes()
+
+ mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
- startTime = startTime.Add(ephemeralLifeTime + 1)
+ startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + 1
@@ -150,15 +173,25 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
}
store := &MockStore{}
- am := MockAccountManager{
- store: store,
- }
+ ctrl := gomock.NewController(t)
+ peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
- mgr := NewEphemeralManager(store, &am)
+ // Expect DeletePeers to be called for the one disconnected peer
+ peersManager.EXPECT().
+ DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
+ DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
+ for _, peerID := range peerIDs {
+ delete(store.account.Peers, peerID)
+ }
+ return nil
+ }).
+ AnyTimes()
+
+ mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
@@ -166,7 +199,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
}
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
- startTime = startTime.Add(ephemeralLifeTime + 1)
+ startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + numberOfEphemeralPeers - 1
@@ -181,25 +214,63 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
testLifeTime = 1 * time.Second
testCleanupWindow = 100 * time.Millisecond
)
+
+ t.Cleanup(func() {
+ timeNow = time.Now
+ })
+ startTime := time.Now()
+ timeNow = func() time.Time {
+ return startTime
+ }
+
mockStore := &MockStore{}
+ account := newAccountWithId(context.Background(), "account", "", "", false)
+ mockStore.account = account
+
+ wg := &sync.WaitGroup{}
+ wg.Add(ephemeralPeers)
mockAM := &MockAccountManager{
store: mockStore,
+ wg: wg,
}
- mockAM.wg = &sync.WaitGroup{}
- mockAM.wg.Add(ephemeralPeers)
- mgr := NewEphemeralManager(mockStore, mockAM)
+
+ ctrl := gomock.NewController(t)
+ peersManager := peers.NewMockManager(ctrl)
+
+ // Set up expectation that DeletePeers will be called once with all peer IDs
+ peersManager.EXPECT().
+ DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true).
+ DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
+ // Simulate the actual deletion behavior
+ for _, peerID := range peerIDs {
+ err := mockAM.DeletePeer(ctx, accountID, peerID, userID)
+ if err != nil {
+ return err
+ }
+ }
+ mockAM.BufferUpdateAccountPeers(ctx, accountID)
+ return nil
+ }).
+ Times(1)
+
+ mgr := NewEphemeralManager(mockStore, peersManager)
mgr.lifeTime = testLifeTime
mgr.cleanupWindow = testCleanupWindow
- account := newAccountWithId(context.Background(), "account", "", "", false)
- mockStore.account = account
+ // Add peers and disconnect them at slightly different times (within cleanup window)
for i := range ephemeralPeers {
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
mockStore.account.Peers[p.ID] = p
- time.Sleep(testCleanupWindow / ephemeralPeers)
mgr.OnPeerDisconnected(context.Background(), p)
+ startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2))
}
- mockAM.wg.Wait()
+
+ // Advance time past the lifetime to trigger cleanup
+ startTime = startTime.Add(testLifeTime + testCleanupWindow)
+
+ // Wait for all deletions to complete
+ wg.Wait()
+
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")
diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go
new file mode 100644
index 000000000..b200b9663
--- /dev/null
+++ b/management/internals/modules/peers/manager.go
@@ -0,0 +1,164 @@
+package peers
+
+//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
+ "github.com/netbirdio/netbird/management/server/account"
+ "github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
+ "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/permissions/modules"
+ "github.com/netbirdio/netbird/management/server/permissions/operations"
+ "github.com/netbirdio/netbird/management/server/store"
+ "github.com/netbirdio/netbird/shared/management/status"
+)
+
+type Manager interface {
+ GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
+ GetPeerAccountID(ctx context.Context, peerID string) (string, error)
+ GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
+ GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
+ DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error
+ SetNetworkMapController(networkMapController network_map.Controller)
+ SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
+ SetAccountManager(accountManager account.Manager)
+}
+
+type managerImpl struct {
+ store store.Store
+ permissionsManager permissions.Manager
+ integratedPeerValidator integrated_validator.IntegratedValidator
+ accountManager account.Manager
+
+ networkMapController network_map.Controller
+}
+
+func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
+ return &managerImpl{
+ store: store,
+ permissionsManager: permissionsManager,
+ }
+}
+
+func (m *managerImpl) SetNetworkMapController(networkMapController network_map.Controller) {
+ m.networkMapController = networkMapController
+}
+
+func (m *managerImpl) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) {
+ m.integratedPeerValidator = integratedPeerValidator
+}
+
+func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
+ m.accountManager = accountManager
+}
+
+func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
+ allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
+ if err != nil {
+ return nil, fmt.Errorf("failed to validate user permissions: %w", err)
+ }
+
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
+ }
+
+ return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
+}
+
+func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
+ allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
+ if err != nil {
+ return nil, fmt.Errorf("failed to validate user permissions: %w", err)
+ }
+
+ if !allowed {
+ return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
+ }
+
+ return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
+}
+
+func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
+ return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
+}
+
+func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
+ return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
+}
+
+func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
+ settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
+ if err != nil {
+ return err
+ }
+ dnsDomain := m.networkMapController.GetDNSDomain(settings)
+
+ for _, peerID := range peerIDs {
+ var eventsToStore []func()
+ err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
+ peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
+ if err != nil {
+ return err
+ }
+
+ if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
+ return nil
+ }
+
+ if err := transaction.RemovePeerFromAllGroups(ctx, peerID); err != nil {
+ return fmt.Errorf("failed to remove peer %s from groups", peerID)
+ }
+
+ if err := m.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil {
+ return err
+ }
+
+ peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peerID)
+ if err != nil {
+ return err
+ }
+ for _, rule := range peerPolicyRules {
+ policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID)
+ if err != nil {
+ return err
+ }
+
+ err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID)
+ if err != nil {
+ return err
+ }
+
+ eventsToStore = append(eventsToStore, func() {
+ m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
+ })
+ }
+
+ if err = transaction.DeletePeer(ctx, accountID, peerID); err != nil {
+ return err
+ }
+
+ eventsToStore = append(eventsToStore, func() {
+ m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
+ })
+
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+ for _, event := range eventsToStore {
+ event()
+ }
+ }
+
+ m.accountManager.UpdateAccountPeers(ctx, accountID)
+
+ return nil
+}
diff --git a/management/server/peers/manager_mock.go b/management/internals/modules/peers/manager_mock.go
similarity index 55%
rename from management/server/peers/manager_mock.go
rename to management/internals/modules/peers/manager_mock.go
index 994f8346b..2e3651e88 100644
--- a/management/server/peers/manager_mock.go
+++ b/management/internals/modules/peers/manager_mock.go
@@ -9,6 +9,9 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
+ network_map "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ account "github.com/netbirdio/netbird/management/server/account"
+ integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
peer "github.com/netbirdio/netbird/management/server/peer"
)
@@ -35,6 +38,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
+// DeletePeers mocks base method.
+func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DeletePeers", ctx, accountID, peerIDs, userID, checkConnected)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// DeletePeers indicates an expected call of DeletePeers.
+func (mr *MockManagerMockRecorder) DeletePeers(ctx, accountID, peerIDs, userID, checkConnected interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeers", reflect.TypeOf((*MockManager)(nil).DeletePeers), ctx, accountID, peerIDs, userID, checkConnected)
+}
+
// GetAllPeers mocks base method.
func (m *MockManager) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()
@@ -94,3 +111,39 @@ func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
}
+
+// SetAccountManager mocks base method.
+func (m *MockManager) SetAccountManager(accountManager account.Manager) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetAccountManager", accountManager)
+}
+
+// SetAccountManager indicates an expected call of SetAccountManager.
+func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
+}
+
+// SetIntegratedPeerValidator mocks base method.
+func (m *MockManager) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetIntegratedPeerValidator", integratedPeerValidator)
+}
+
+// SetIntegratedPeerValidator indicates an expected call of SetIntegratedPeerValidator.
+func (mr *MockManagerMockRecorder) SetIntegratedPeerValidator(integratedPeerValidator interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIntegratedPeerValidator", reflect.TypeOf((*MockManager)(nil).SetIntegratedPeerValidator), integratedPeerValidator)
+}
+
+// SetNetworkMapController mocks base method.
+func (m *MockManager) SetNetworkMapController(networkMapController network_map.Controller) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetNetworkMapController", networkMapController)
+}
+
+// SetNetworkMapController indicates an expected call of SetNetworkMapController.
+func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
+}
diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go
index 16e93a549..57b3fac78 100644
--- a/management/internals/server/boot.go
+++ b/management/internals/server/boot.go
@@ -10,9 +10,9 @@ import (
"slices"
"time"
- "github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
+ "github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -22,7 +22,7 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
- "github.com/netbirdio/netbird/management/server"
+ nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
@@ -57,7 +57,7 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics {
func (s *BaseServer) Store() store.Store {
return Create(s, func() store.Store {
- store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false)
+ store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false)
if err != nil {
log.Fatalf("failed to create store: %v", err)
}
@@ -73,17 +73,17 @@ func (s *BaseServer) EventStore() activity.Store {
log.Fatalf("failed to initialize integration metrics: %v", err)
}
- eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics)
+ eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
log.Fatalf("failed to initialize event store: %v", err)
}
- if s.config.DataStoreEncryptionKey != key {
- log.WithContext(context.Background()).Infof("update config with activity store key")
- s.config.DataStoreEncryptionKey = key
- err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config)
+ if s.Config.DataStoreEncryptionKey != key {
+ log.WithContext(context.Background()).Infof("update Config with activity store key")
+ s.Config.DataStoreEncryptionKey = key
+ err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config)
if err != nil {
- log.Fatalf("failed to update config with activity store: %v", err)
+ log.Fatalf("failed to update Config with activity store: %v", err)
}
}
@@ -93,7 +93,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
- httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager())
+ httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -103,14 +103,14 @@ func (s *BaseServer) APIHandler() http.Handler {
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
- trustedPeers := s.config.ReverseProxy.TrustedPeers
+ trustedPeers := s.Config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
- trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies
- trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount
+ trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies
+ trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
@@ -128,15 +128,15 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
- if s.config.HttpConfig.LetsEncryptDomain != "" {
- certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
+ if s.Config.HttpConfig.LetsEncryptDomain != "" {
+ certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil {
log.Fatalf("failed to create certificate manager: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
- } else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
- tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
+ } else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
+ tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
if err != nil {
log.Fatalf("cannot load TLS credentials: %v", err)
}
@@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
- srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator())
+ srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}
@@ -180,7 +180,7 @@ func unaryInterceptor(
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
- reqID := uuid.New().String()
+ reqID := xid.New().String()
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint
@@ -194,7 +194,7 @@ func streamInterceptor(
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
- reqID := uuid.New().String()
+ reqID := xid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)
diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go
index ddd81daa2..3442c7646 100644
--- a/management/internals/server/controllers.go
+++ b/management/internals/server/controllers.go
@@ -6,17 +6,21 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
+ "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
+ "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
- "github.com/netbirdio/netbird/management/server/peers/ephemeral"
- "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
)
-func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
- return Create(s, func() *server.PeersUpdateManager {
- return server.NewPeersUpdateManager(s.Metrics())
+func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
+ return Create(s, func() network_map.PeersUpdateManager {
+ return update_channel.NewPeersUpdateManager(s.Metrics())
})
}
@@ -40,26 +44,46 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller {
})
}
-func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager {
- return Create(s, func() *server.TimeBasedAuthSecretsManager {
- return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
+func (s *BaseServer) SecretsManager() grpc.SecretsManager {
+ return Create(s, func() grpc.SecretsManager {
+ secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager())
+ if err != nil {
+ log.Fatalf("failed to create secrets manager: %v", err)
+ }
+ return secretsManager
})
}
func (s *BaseServer) AuthManager() auth.Manager {
return Create(s, func() auth.Manager {
return auth.NewManager(s.Store(),
- s.config.HttpConfig.AuthIssuer,
- s.config.HttpConfig.AuthAudience,
- s.config.HttpConfig.AuthKeysLocation,
- s.config.HttpConfig.AuthUserIDClaim,
- s.config.GetAuthAudiences(),
- s.config.HttpConfig.IdpSignKeyRefreshEnabled)
+ s.Config.HttpConfig.AuthIssuer,
+ s.Config.HttpConfig.AuthAudience,
+ s.Config.HttpConfig.AuthKeysLocation,
+ s.Config.HttpConfig.AuthUserIDClaim,
+ s.Config.GetAuthAudiences(),
+ s.Config.HttpConfig.IdpSignKeyRefreshEnabled)
})
}
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
return Create(s, func() ephemeral.Manager {
- return manager.NewEphemeralManager(s.Store(), s.AccountManager())
+ return manager.NewEphemeralManager(s.Store(), s.PeersManager())
})
}
+
+func (s *BaseServer) NetworkMapController() network_map.Controller {
+ return Create(s, func() network_map.Controller {
+ return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config)
+ })
+}
+
+func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
+ return Create(s, func() *server.AccountRequestBuffer {
+ return server.NewAccountRequestBuffer(context.Background(), s.Store())
+ })
+}
+
+func (s *BaseServer) DNSDomain() string {
+ return s.dnsDomain
+}
diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go
index daec4ef6f..af9ca5f2d 100644
--- a/management/internals/server/modules.go
+++ b/management/internals/server/modules.go
@@ -2,10 +2,12 @@ package server
import (
"context"
+ "os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -14,20 +16,29 @@ import (
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
- "github.com/netbirdio/netbird/management/server/peers"
+
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
)
+const (
+ geolocationDisabledKey = "NB_DISABLE_GEOLOCATION"
+)
+
func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
+ if os.Getenv(geolocationDisabledKey) == "true" {
+ log.Info("geolocation service is disabled, skipping initialization")
+ return nil
+ }
+
return Create(s, func() geolocation.Geolocation {
- geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate)
+ geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate)
if err != nil {
log.Fatalf("could not initialize geolocation service: %v", err)
}
- log.Infof("geolocation service has been initialized from %s", s.config.Datadir)
+ log.Infof("geolocation service has been initialized from %s", s.Config.Datadir)
return geo
})
@@ -35,7 +46,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
func (s *BaseServer) PermissionsManager() permissions.Manager {
return Create(s, func() permissions.Manager {
- return integrations.InitPermissionsManager(s.Store())
+ manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
+
+ s.AfterInit(func(s *BaseServer) {
+ manager.SetAccountManager(s.AccountManager())
+ })
+
+ return manager
})
}
@@ -54,21 +71,22 @@ func (s *BaseServer) SettingsManager() settings.Manager {
func (s *BaseServer) PeersManager() peers.Manager {
return Create(s, func() peers.Manager {
- return peers.NewManager(s.Store(), s.PermissionsManager())
+ manager := peers.NewManager(s.Store(), s.PermissionsManager())
+ s.AfterInit(func(s *BaseServer) {
+ manager.SetNetworkMapController(s.NetworkMapController())
+ manager.SetIntegratedPeerValidator(s.IntegratedValidator())
+ manager.SetAccountManager(s.AccountManager())
+ })
+ return manager
})
}
func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
- accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain,
- s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
+ accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
-
- s.AfterInit(func(s *BaseServer) {
- accountManager.SetEphemeralManager(s.EphemeralManager())
- })
return accountManager
})
}
@@ -77,8 +95,8 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
- if s.config.IdpManagerConfig != nil {
- idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics())
+ if s.Config.IdpManagerConfig != nil {
+ idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP manager: %v", err)
}
diff --git a/management/internals/server/server.go b/management/internals/server/server.go
index ab1c2ebe7..d9c715225 100644
--- a/management/internals/server/server.go
+++ b/management/internals/server/server.go
@@ -41,10 +41,10 @@ type Server interface {
}
// Server holds the HTTP BaseServer instance.
-// Add any additional fields you need, such as database connections, config, etc.
+// Add any additional fields you need, such as database connections, Config, etc.
type BaseServer struct {
- // config holds the server configuration
- config *nbconfig.Config
+ // Config holds the server configuration
+ Config *nbconfig.Config
// container of dependencies, each dependency is identified by a unique string.
container map[string]any
// AfterInit is a function that will be called after the server is initialized
@@ -70,7 +70,7 @@ type BaseServer struct {
// NewServer initializes and configures a new Server instance
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
return &BaseServer{
- config: config,
+ Config: config,
container: make(map[string]any),
dnsDomain: dnsDomain,
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
@@ -103,14 +103,14 @@ func (s *BaseServer) Start(ctx context.Context) error {
var tlsConfig *tls.Config
tlsEnabled := false
- if s.config.HttpConfig.LetsEncryptDomain != "" {
- s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
+ if s.Config.HttpConfig.LetsEncryptDomain != "" {
+ s.certManager, err = encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil {
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
tlsEnabled = true
- } else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
- tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
+ } else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
+ tlsConfig, err = loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
if err != nil {
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
return err
@@ -126,8 +126,8 @@ func (s *BaseServer) Start(ctx context.Context) error {
if !s.disableMetrics {
idpManager := "disabled"
- if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" {
- idpManager = s.config.IdpManagerConfig.ManagerType
+ if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" {
+ idpManager = s.Config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
go metricsWorker.Run(srvCtx)
@@ -183,7 +183,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
- s.update = version.NewUpdate("nb/management")
+ s.update = version.NewUpdateAndStart("nb/management")
s.update.SetDaemonVersion(version.NetbirdVersion())
s.update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go
new file mode 100644
index 000000000..f984c73df
--- /dev/null
+++ b/management/internals/shared/grpc/conversion.go
@@ -0,0 +1,450 @@
+package grpc
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+
+ integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
+ "github.com/netbirdio/netbird/client/ssh/auth"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
+ nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/posture"
+ "github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/route"
+ "github.com/netbirdio/netbird/shared/management/proto"
+ "github.com/netbirdio/netbird/shared/sshauth"
+)
+
+func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
+ if config == nil {
+ return nil
+ }
+
+ var stuns []*proto.HostConfig
+ for _, stun := range config.Stuns {
+ stuns = append(stuns, &proto.HostConfig{
+ Uri: stun.URI,
+ Protocol: ToResponseProto(stun.Proto),
+ })
+ }
+
+ var turns []*proto.ProtectedHostConfig
+ if config.TURNConfig != nil {
+ for _, turn := range config.TURNConfig.Turns {
+ var username string
+ var password string
+ if turnCredentials != nil {
+ username = turnCredentials.Payload
+ password = turnCredentials.Signature
+ } else {
+ username = turn.Username
+ password = turn.Password
+ }
+ turns = append(turns, &proto.ProtectedHostConfig{
+ HostConfig: &proto.HostConfig{
+ Uri: turn.URI,
+ Protocol: ToResponseProto(turn.Proto),
+ },
+ User: username,
+ Password: password,
+ })
+ }
+ }
+
+ var relayCfg *proto.RelayConfig
+ if config.Relay != nil && len(config.Relay.Addresses) > 0 {
+ relayCfg = &proto.RelayConfig{
+ Urls: config.Relay.Addresses,
+ }
+
+ if relayToken != nil {
+ relayCfg.TokenPayload = relayToken.Payload
+ relayCfg.TokenSignature = relayToken.Signature
+ }
+ }
+
+ var signalCfg *proto.HostConfig
+ if config.Signal != nil {
+ signalCfg = &proto.HostConfig{
+ Uri: config.Signal.URI,
+ Protocol: ToResponseProto(config.Signal.Proto),
+ }
+ }
+
+ nbConfig := &proto.NetbirdConfig{
+ Stuns: stuns,
+ Turns: turns,
+ Signal: signalCfg,
+ Relay: relayCfg,
+ }
+
+ return nbConfig
+}
+
+func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig {
+ netmask, _ := network.Net.Mask.Size()
+ fqdn := peer.FQDN(dnsName)
+
+ sshConfig := &proto.SSHConfig{
+ SshEnabled: peer.SSHEnabled || enableSSH,
+ }
+
+ if sshConfig.SshEnabled {
+ sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
+ }
+
+ return &proto.PeerConfig{
+ Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask),
+ SshConfig: sshConfig,
+ Fqdn: fqdn,
+ RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
+ LazyConnectionEnabled: settings.LazyConnectionEnabled,
+ AutoUpdate: &proto.AutoUpdateSettings{
+ Version: settings.AutoUpdateVersion,
+ },
+ }
+}
+
+func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
+ response := &proto.SyncResponse{
+ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
+ NetworkMap: &proto.NetworkMap{
+ Serial: networkMap.Network.CurrentSerial(),
+ Routes: toProtocolRoutes(networkMap.Routes),
+ DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
+ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
+ },
+ Checks: toProtocolChecks(ctx, checks),
+ }
+
+ nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
+ extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
+ response.NetbirdConfig = extendedConfig
+
+ response.NetworkMap.PeerConfig = response.PeerConfig
+
+ remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
+ remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
+ response.RemotePeers = remotePeers
+ response.NetworkMap.RemotePeers = remotePeers
+ response.RemotePeersIsEmpty = len(remotePeers) == 0
+ response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
+
+ response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
+
+ firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
+ response.NetworkMap.FirewallRules = firewallRules
+ response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
+
+ routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
+ response.NetworkMap.RoutesFirewallRules = routesFirewallRules
+ response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
+
+ if networkMap.ForwardingRules != nil {
+ forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
+ for _, rule := range networkMap.ForwardingRules {
+ forwardingRules = append(forwardingRules, rule.ToProto())
+ }
+ response.NetworkMap.ForwardingRules = forwardingRules
+ }
+
+ if networkMap.AuthorizedUsers != nil {
+ hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
+ userIDClaim := auth.DefaultUserIDClaim
+ if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
+ userIDClaim = httpConfig.AuthUserIDClaim
+ }
+ response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
+ }
+
+ return response
+}
+
+func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
+ userIDToIndex := make(map[string]uint32)
+ var hashedUsers [][]byte
+ machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
+
+ for machineUser, users := range authorizedUsers {
+ indexes := make([]uint32, 0, len(users))
+ for userID := range users {
+ idx, exists := userIDToIndex[userID]
+ if !exists {
+ hash, err := sshauth.HashUserID(userID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
+ continue
+ }
+ idx = uint32(len(hashedUsers))
+ userIDToIndex[userID] = idx
+ hashedUsers = append(hashedUsers, hash[:])
+ }
+ indexes = append(indexes, idx)
+ }
+ machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
+ }
+
+ return hashedUsers, machineUsers
+}
+
+func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
+ for _, rPeer := range peers {
+ dst = append(dst, &proto.RemotePeerConfig{
+ WgPubKey: rPeer.Key,
+ AllowedIps: []string{rPeer.IP.String() + "/32"},
+ SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
+ Fqdn: rPeer.FQDN(dnsName),
+ AgentVersion: rPeer.Meta.WtVersion,
+ })
+ }
+ return dst
+}
+
+// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
+func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
+ protoUpdate := &proto.DNSConfig{
+ ServiceEnable: update.ServiceEnable,
+ CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
+ NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
+ ForwarderPort: forwardPort,
+ }
+
+ for _, zone := range update.CustomZones {
+ protoZone := convertToProtoCustomZone(zone)
+ protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
+ }
+
+ for _, nsGroup := range update.NameServerGroups {
+ cacheKey := nsGroup.ID
+ if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
+ protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
+ } else {
+ protoGroup := convertToProtoNameServerGroup(nsGroup)
+ cache.SetNameServerGroup(cacheKey, protoGroup)
+ protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
+ }
+ }
+
+ return protoUpdate
+}
+
+func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
+ switch configProto {
+ case nbconfig.UDP:
+ return proto.HostConfig_UDP
+ case nbconfig.DTLS:
+ return proto.HostConfig_DTLS
+ case nbconfig.HTTP:
+ return proto.HostConfig_HTTP
+ case nbconfig.HTTPS:
+ return proto.HostConfig_HTTPS
+ case nbconfig.TCP:
+ return proto.HostConfig_TCP
+ default:
+ panic(fmt.Errorf("unexpected config protocol type %v", configProto))
+ }
+}
+
+func toProtocolRoutes(routes []*route.Route) []*proto.Route {
+ protoRoutes := make([]*proto.Route, 0, len(routes))
+ for _, r := range routes {
+ protoRoutes = append(protoRoutes, toProtocolRoute(r))
+ }
+ return protoRoutes
+}
+
+func toProtocolRoute(route *route.Route) *proto.Route {
+ return &proto.Route{
+ ID: string(route.ID),
+ NetID: string(route.NetID),
+ Network: route.Network.String(),
+ Domains: route.Domains.ToPunycodeList(),
+ NetworkType: int64(route.NetworkType),
+ Peer: route.Peer,
+ Metric: int64(route.Metric),
+ Masquerade: route.Masquerade,
+ KeepRoute: route.KeepRoute,
+ SkipAutoApply: route.SkipAutoApply,
+ }
+}
+
+// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
+func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
+ result := make([]*proto.FirewallRule, len(rules))
+ for i := range rules {
+ rule := rules[i]
+
+ fwRule := &proto.FirewallRule{
+ PolicyID: []byte(rule.PolicyID),
+ PeerIP: rule.PeerIP,
+ Direction: getProtoDirection(rule.Direction),
+ Action: getProtoAction(rule.Action),
+ Protocol: getProtoProtocol(rule.Protocol),
+ Port: rule.Port,
+ }
+
+ if shouldUsePortRange(fwRule) {
+ fwRule.PortInfo = rule.PortRange.ToProto()
+ }
+
+ result[i] = fwRule
+ }
+ return result
+}
+
+// getProtoDirection converts the direction to proto.RuleDirection.
+func getProtoDirection(direction int) proto.RuleDirection {
+ if direction == types.FirewallRuleDirectionOUT {
+ return proto.RuleDirection_OUT
+ }
+ return proto.RuleDirection_IN
+}
+
+func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
+ result := make([]*proto.RouteFirewallRule, len(rules))
+ for i := range rules {
+ rule := rules[i]
+ result[i] = &proto.RouteFirewallRule{
+ SourceRanges: rule.SourceRanges,
+ Action: getProtoAction(rule.Action),
+ Destination: rule.Destination,
+ Protocol: getProtoProtocol(rule.Protocol),
+ PortInfo: getProtoPortInfo(rule),
+ IsDynamic: rule.IsDynamic,
+ Domains: rule.Domains.ToPunycodeList(),
+ PolicyID: []byte(rule.PolicyID),
+ RouteID: string(rule.RouteID),
+ }
+ }
+
+ return result
+}
+
+// getProtoAction converts the action to proto.RuleAction.
+func getProtoAction(action string) proto.RuleAction {
+ if action == string(types.PolicyTrafficActionDrop) {
+ return proto.RuleAction_DROP
+ }
+ return proto.RuleAction_ACCEPT
+}
+
+// getProtoProtocol converts the protocol to proto.RuleProtocol.
+func getProtoProtocol(protocol string) proto.RuleProtocol {
+ switch types.PolicyRuleProtocolType(protocol) {
+ case types.PolicyRuleProtocolALL:
+ return proto.RuleProtocol_ALL
+ case types.PolicyRuleProtocolTCP:
+ return proto.RuleProtocol_TCP
+ case types.PolicyRuleProtocolUDP:
+ return proto.RuleProtocol_UDP
+ case types.PolicyRuleProtocolICMP:
+ return proto.RuleProtocol_ICMP
+ default:
+ return proto.RuleProtocol_UNKNOWN
+ }
+}
+
+// getProtoPortInfo converts the port info to proto.PortInfo.
+func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
+ var portInfo proto.PortInfo
+ if rule.Port != 0 {
+ portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
+ } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
+ portInfo.PortSelection = &proto.PortInfo_Range_{
+ Range: &proto.PortInfo_Range{
+ Start: uint32(portRange.Start),
+ End: uint32(portRange.End),
+ },
+ }
+ }
+ return &portInfo
+}
+
+func shouldUsePortRange(rule *proto.FirewallRule) bool {
+ return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
+}
+
+// Helper function to convert nbdns.CustomZone to proto.CustomZone
+func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
+ protoZone := &proto.CustomZone{
+ Domain: zone.Domain,
+ Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
+ }
+ for _, record := range zone.Records {
+ protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
+ Name: record.Name,
+ Type: int64(record.Type),
+ Class: record.Class,
+ TTL: int64(record.TTL),
+ RData: record.RData,
+ })
+ }
+ return protoZone
+}
+
+// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
+func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
+ protoGroup := &proto.NameServerGroup{
+ Primary: nsGroup.Primary,
+ Domains: nsGroup.Domains,
+ SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
+ NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
+ }
+ for _, ns := range nsGroup.NameServers {
+ protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
+ IP: ns.IP.String(),
+ Port: int64(ns.Port),
+ NSType: int64(ns.NSType),
+ })
+ }
+ return protoGroup
+}
+
+// buildJWTConfig constructs JWT configuration for SSH servers from management server config
+func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
+ if config == nil || config.AuthAudience == "" {
+ return nil
+ }
+
+ issuer := strings.TrimSpace(config.AuthIssuer)
+ if issuer == "" && deviceFlowConfig != nil {
+ if d := deriveIssuerFromTokenEndpoint(deviceFlowConfig.ProviderConfig.TokenEndpoint); d != "" {
+ issuer = d
+ }
+ }
+ if issuer == "" {
+ return nil
+ }
+
+ keysLocation := strings.TrimSpace(config.AuthKeysLocation)
+ if keysLocation == "" {
+ keysLocation = strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json"
+ }
+
+ return &proto.JWTConfig{
+ Issuer: issuer,
+ Audience: config.AuthAudience,
+ KeysLocation: keysLocation,
+ }
+}
+
+// deriveIssuerFromTokenEndpoint extracts the issuer URL from a token endpoint
+func deriveIssuerFromTokenEndpoint(tokenEndpoint string) string {
+ if tokenEndpoint == "" {
+ return ""
+ }
+
+ u, err := url.Parse(tokenEndpoint)
+ if err != nil {
+ return ""
+ }
+
+ return fmt.Sprintf("%s://%s/", u.Scheme, u.Host)
+}
diff --git a/management/internals/shared/grpc/conversion_test.go b/management/internals/shared/grpc/conversion_test.go
new file mode 100644
index 000000000..701271345
--- /dev/null
+++ b/management/internals/shared/grpc/conversion_test.go
@@ -0,0 +1,150 @@
+package grpc
+
+import (
+ "fmt"
+ "net/netip"
+ "reflect"
+ "testing"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
+)
+
+func TestToProtocolDNSConfigWithCache(t *testing.T) {
+ var cache cache.DNSConfigCache
+
+ // Create two different configs
+ config1 := nbdns.Config{
+ ServiceEnable: true,
+ CustomZones: []nbdns.CustomZone{
+ {
+ Domain: "example.com",
+ Records: []nbdns.SimpleRecord{
+ {Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
+ },
+ },
+ },
+ NameServerGroups: []*nbdns.NameServerGroup{
+ {
+ ID: "group1",
+ Name: "Group 1",
+ NameServers: []nbdns.NameServer{
+ {IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
+ },
+ },
+ },
+ }
+
+ config2 := nbdns.Config{
+ ServiceEnable: true,
+ CustomZones: []nbdns.CustomZone{
+ {
+ Domain: "example.org",
+ Records: []nbdns.SimpleRecord{
+ {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
+ },
+ },
+ },
+ NameServerGroups: []*nbdns.NameServerGroup{
+ {
+ ID: "group2",
+ Name: "Group 2",
+ NameServers: []nbdns.NameServer{
+ {IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
+ },
+ },
+ },
+ }
+
+ // First run with config1
+ result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
+
+ // Second run with config2
+ result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
+
+ // Third run with config1 again
+ result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
+
+ // Verify that result1 and result3 are identical
+ if !reflect.DeepEqual(result1, result3) {
+ t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
+ }
+
+ // Verify that result2 is different from result1 and result3
+ if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
+ t.Errorf("Results should be different for different inputs")
+ }
+
+ if _, exists := cache.GetNameServerGroup("group1"); !exists {
+ t.Errorf("Cache should contain name server group 'group1'")
+ }
+
+ if _, exists := cache.GetNameServerGroup("group2"); !exists {
+ t.Errorf("Cache should contain name server group 'group2'")
+ }
+}
+
+func BenchmarkToProtocolDNSConfig(b *testing.B) {
+ sizes := []int{10, 100, 1000}
+
+ for _, size := range sizes {
+ testData := generateTestData(size)
+
+ b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
+ cache := &cache.DNSConfigCache{}
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
+ }
+ })
+
+ b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ cache := &cache.DNSConfigCache{}
+ toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
+ }
+ })
+ }
+}
+
+func generateTestData(size int) nbdns.Config {
+ config := nbdns.Config{
+ ServiceEnable: true,
+ CustomZones: make([]nbdns.CustomZone, size),
+ NameServerGroups: make([]*nbdns.NameServerGroup, size),
+ }
+
+ for i := 0; i < size; i++ {
+ config.CustomZones[i] = nbdns.CustomZone{
+ Domain: fmt.Sprintf("domain%d.com", i),
+ Records: []nbdns.SimpleRecord{
+ {
+ Name: fmt.Sprintf("record%d", i),
+ Type: 1,
+ Class: "IN",
+ TTL: 3600,
+ RData: "192.168.1.1",
+ },
+ },
+ }
+
+ config.NameServerGroups[i] = &nbdns.NameServerGroup{
+ ID: fmt.Sprintf("group%d", i),
+ Primary: i == 0,
+ Domains: []string{fmt.Sprintf("domain%d.com", i)},
+ SearchDomainsEnabled: true,
+ NameServers: []nbdns.NameServer{
+ {
+ IP: netip.MustParseAddr("8.8.8.8"),
+ Port: 53,
+ NSType: 1,
+ },
+ },
+ }
+ }
+
+ return config
+}
diff --git a/management/server/loginfilter.go b/management/internals/shared/grpc/loginfilter.go
similarity index 99%
rename from management/server/loginfilter.go
rename to management/internals/shared/grpc/loginfilter.go
index 8604af6e2..59f69dd90 100644
--- a/management/server/loginfilter.go
+++ b/management/internals/shared/grpc/loginfilter.go
@@ -1,4 +1,4 @@
-package server
+package grpc
import (
"hash/fnv"
diff --git a/management/server/loginfilter_test.go b/management/internals/shared/grpc/loginfilter_test.go
similarity index 99%
rename from management/server/loginfilter_test.go
rename to management/internals/shared/grpc/loginfilter_test.go
index 65782dd9d..8b26e14ab 100644
--- a/management/server/loginfilter_test.go
+++ b/management/internals/shared/grpc/loginfilter_test.go
@@ -1,4 +1,4 @@
-package server
+package grpc
import (
"hash/fnv"
diff --git a/management/server/grpcserver.go b/management/internals/shared/grpc/server.go
similarity index 69%
rename from management/server/grpcserver.go
rename to management/internals/shared/grpc/server.go
index 12b59b691..ad6b34c5f 100644
--- a/management/server/grpcserver.go
+++ b/management/internals/shared/grpc/server.go
@@ -1,4 +1,4 @@
-package server
+package grpc
import (
"context"
@@ -7,8 +7,10 @@ import (
"net"
"net/netip"
"os"
+ "strconv"
"strings"
"sync"
+ "sync/atomic"
"time"
pb "github.com/golang/protobuf/proto" // nolint
@@ -20,9 +22,8 @@ import (
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
- integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
- "github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store"
@@ -44,49 +45,49 @@ import (
const (
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
envBlockPeers = "NB_BLOCK_SAME_PEERS"
+ envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
+
+ defaultSyncLim = 1000
)
-// GRPCServer an instance of a Management gRPC API server
-type GRPCServer struct {
+// Server an instance of a Management gRPC API server
+type Server struct {
accountManager account.Manager
settingsManager settings.Manager
- wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
- peersUpdateManager *PeersUpdateManager
- config *nbconfig.Config
- secretsManager SecretsManager
- appMetrics telemetry.AppMetrics
- ephemeralManager ephemeral.Manager
- peerLocks sync.Map
- authManager auth.Manager
+ config *nbconfig.Config
+ secretsManager SecretsManager
+ appMetrics telemetry.AppMetrics
+ peerLocks sync.Map
+ authManager auth.Manager
logBlockedPeers bool
blockPeersWithSameConfig bool
integratedPeerValidator integrated_validator.IntegratedValidator
+
+ loginFilter *loginFilter
+
+ networkMapController network_map.Controller
+
+ syncSem atomic.Int32
+ syncLim int32
}
// NewServer creates a new Management server
func NewServer(
- ctx context.Context,
config *nbconfig.Config,
accountManager account.Manager,
settingsManager settings.Manager,
- peersUpdateManager *PeersUpdateManager,
secretsManager SecretsManager,
appMetrics telemetry.AppMetrics,
- ephemeralManager ephemeral.Manager,
authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator,
-) (*GRPCServer, error) {
- key, err := wgtypes.GeneratePrivateKey()
- if err != nil {
- return nil, err
- }
-
+ networkMapController network_map.Controller,
+) (*Server, error) {
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
- err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
- return int64(len(peersUpdateManager.peerChannels))
+ err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
+ return int64(networkMapController.CountStreams())
})
if err != nil {
return nil, err
@@ -96,24 +97,36 @@ func NewServer(
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
- return &GRPCServer{
- wgKey: key,
- // peerKey -> event channel
- peersUpdateManager: peersUpdateManager,
+ syncLim := int32(defaultSyncLim)
+ if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
+ syncLimParsed, err := strconv.Atoi(syncLimStr)
+ if err != nil {
+ log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
+ } else {
+ //nolint:gosec
+ syncLim = int32(syncLimParsed)
+ }
+ }
+
+ return &Server{
accountManager: accountManager,
settingsManager: settingsManager,
config: config,
secretsManager: secretsManager,
authManager: authManager,
appMetrics: appMetrics,
- ephemeralManager: ephemeralManager,
logBlockedPeers: logBlockedPeers,
blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator,
+ networkMapController: networkMapController,
+
+ loginFilter: newLoginFilter(),
+
+ syncLim: syncLim,
}, nil
}
-func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
+func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
ip := ""
p, ok := peer.FromContext(ctx)
if ok {
@@ -121,10 +134,6 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto
}
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
- start := time.Now()
- defer func() {
- log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
- }()
// todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
@@ -135,8 +144,14 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto
nanos := int32(now.Nanosecond())
expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos}
+ key, err := s.secretsManager.GetWGKey()
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to get wireguard key: %v", err)
+ return nil, errors.New("failed to get wireguard key")
+ }
+
return &proto.ServerKeyResponse{
- Key: s.wgKey.PublicKey().String(),
+ Key: key.PublicKey().String(),
ExpiresAt: expiresAt,
}, nil
}
@@ -150,7 +165,12 @@ func getRealIP(ctx context.Context) net.IP {
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
-func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
+func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
+ if s.syncSem.Load() >= s.syncLim {
+ return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
+ }
+ s.syncSem.Add(1)
+
reqStart := time.Now()
ctx := srv.Context()
@@ -158,20 +178,22 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil {
+ s.syncSem.Add(-1)
return err
}
realIP := getRealIP(ctx)
sRealIP := realIP.String()
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP)
- if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
+ if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
}
if s.logBlockedPeers {
- log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
+ log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
}
if s.blockPeersWithSameConfig {
+ s.syncSem.Add(-1)
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
}
}
@@ -183,48 +205,61 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
- unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
- defer func() {
- if unlock != nil {
- unlock()
- }
- }()
-
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
+ s.syncSem.Add(-1)
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
+ s.syncSem.Add(-1)
return err
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
+ start := time.Now()
+ unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
+ defer func() {
+ if unlock != nil {
+ unlock()
+ }
+ }()
+ log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
+
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
if syncReq.GetMeta() == nil {
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
- peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
+ metahash := metaHash(peerMeta, realIP.String())
+ s.loginFilter.addLogin(peerKey.String(), metahash)
+
+ peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
+ s.syncSem.Add(-1)
return mapError(ctx, err)
}
- err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
+ err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort)
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
+ s.syncSem.Add(-1)
return err
}
- updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
-
- s.ephemeralManager.OnPeerConnected(ctx, peer)
+ updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID)
+ if err != nil {
+ log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
+ s.syncSem.Add(-1)
+ s.cancelPeerRoutines(ctx, accountID, peer)
+ return err
+ }
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
@@ -235,13 +270,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
unlock()
unlock = nil
- log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
+ s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
-func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
+func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
for {
select {
@@ -275,14 +310,20 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
-func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
- encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
+func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
+ key, err := s.secretsManager.GetWGKey()
+ if err != nil {
+ s.cancelPeerRoutines(ctx, accountID, peer)
+ return status.Errorf(codes.Internal, "failed processing update message")
+ }
+
+ encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
- WgPubKey: s.wgKey.PublicKey().String(),
+ WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
@@ -293,7 +334,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
return nil
}
-func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
+func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
@@ -301,14 +342,13 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}
- s.peersUpdateManager.CloseChannel(ctx, peer.ID)
+ s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
s.secretsManager.CancelRefresh(peer.ID)
- s.ephemeralManager.OnPeerDisconnected(ctx, peer)
- log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
+ log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
}
-func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
+func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
if s.authManager == nil {
return "", status.Errorf(codes.Internal, "missing auth manager")
}
@@ -342,7 +382,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
return userAuth.UserId, nil
}
-func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
+func (s *Server) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
start := time.Now()
@@ -450,14 +490,19 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
}
}
-func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
+func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
}
- err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
+ key, err := s.secretsManager.GetWGKey()
+ if err != nil {
+ return wgtypes.Key{}, status.Errorf(codes.Internal, "failed processing request")
+ }
+
+ err = encryption.DecryptMessage(peerKey, key, req.Body, parsed)
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
}
@@ -469,11 +514,10 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
// In case it is, the login is successful
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful
-func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
+func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
reqStart := time.Now()
realIP := getRealIP(ctx)
sRealIP := realIP.String()
- log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(ctx, req, loginReq)
@@ -483,9 +527,9 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP)
- if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
+ if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers {
- log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
+ log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
@@ -509,6 +553,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
+ log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
+
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
@@ -546,30 +592,31 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, mapError(ctx, err)
}
- // if the login request contains setup key then it is a registration request
- if loginReq.GetSetupKey() != "" {
- s.ephemeralManager.OnPeerDisconnected(ctx, peer)
- }
-
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
- encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
+ key, err := s.secretsManager.GetWGKey()
+ if err != nil {
+ log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
+ return nil, status.Errorf(codes.Internal, "failed logging in peer")
+ }
+
+ encryptedResp, err := encryption.EncryptMessage(peerKey, key, loginResp)
if err != nil {
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
return &proto.EncryptedMessage{
- WgPubKey: s.wgKey.PublicKey().String(),
+ WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
}, nil
}
-func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
+func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
var relayToken *Token
var err error
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
@@ -588,7 +635,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
- PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings),
+ PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
Checks: toProtocolChecks(ctx, postureChecks),
}
@@ -600,7 +647,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
//
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
// registered or if it uses a setup key to register.
-func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
+func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
userID := ""
if loginReq.GetJwtToken() != "" {
var err error
@@ -620,166 +667,13 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
return userID, nil
}
-func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
- switch configProto {
- case nbconfig.UDP:
- return proto.HostConfig_UDP
- case nbconfig.DTLS:
- return proto.HostConfig_DTLS
- case nbconfig.HTTP:
- return proto.HostConfig_HTTP
- case nbconfig.HTTPS:
- return proto.HostConfig_HTTPS
- case nbconfig.TCP:
- return proto.HostConfig_TCP
- default:
- panic(fmt.Errorf("unexpected config protocol type %v", configProto))
- }
-}
-
-func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
- if config == nil {
- return nil
- }
-
- var stuns []*proto.HostConfig
- for _, stun := range config.Stuns {
- stuns = append(stuns, &proto.HostConfig{
- Uri: stun.URI,
- Protocol: ToResponseProto(stun.Proto),
- })
- }
-
- var turns []*proto.ProtectedHostConfig
- if config.TURNConfig != nil {
- for _, turn := range config.TURNConfig.Turns {
- var username string
- var password string
- if turnCredentials != nil {
- username = turnCredentials.Payload
- password = turnCredentials.Signature
- } else {
- username = turn.Username
- password = turn.Password
- }
- turns = append(turns, &proto.ProtectedHostConfig{
- HostConfig: &proto.HostConfig{
- Uri: turn.URI,
- Protocol: ToResponseProto(turn.Proto),
- },
- User: username,
- Password: password,
- })
- }
- }
-
- var relayCfg *proto.RelayConfig
- if config.Relay != nil && len(config.Relay.Addresses) > 0 {
- relayCfg = &proto.RelayConfig{
- Urls: config.Relay.Addresses,
- }
-
- if relayToken != nil {
- relayCfg.TokenPayload = relayToken.Payload
- relayCfg.TokenSignature = relayToken.Signature
- }
- }
-
- var signalCfg *proto.HostConfig
- if config.Signal != nil {
- signalCfg = &proto.HostConfig{
- Uri: config.Signal.URI,
- Protocol: ToResponseProto(config.Signal.Proto),
- }
- }
-
- nbConfig := &proto.NetbirdConfig{
- Stuns: stuns,
- Turns: turns,
- Signal: signalCfg,
- Relay: relayCfg,
- }
-
- return nbConfig
-}
-
-func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
- netmask, _ := network.Net.Mask.Size()
- fqdn := peer.FQDN(dnsName)
- return &proto.PeerConfig{
- Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
- SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
- Fqdn: fqdn,
- RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
- LazyConnectionEnabled: settings.LazyConnectionEnabled,
- }
-}
-
-func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
- response := &proto.SyncResponse{
- PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
- NetworkMap: &proto.NetworkMap{
- Serial: networkMap.Network.CurrentSerial(),
- Routes: toProtocolRoutes(networkMap.Routes),
- DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
- },
- Checks: toProtocolChecks(ctx, checks),
- }
-
- nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
- extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
- response.NetbirdConfig = extendedConfig
-
- response.NetworkMap.PeerConfig = response.PeerConfig
-
- remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
- remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
- response.RemotePeers = remotePeers
- response.NetworkMap.RemotePeers = remotePeers
- response.RemotePeersIsEmpty = len(remotePeers) == 0
- response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
-
- response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
-
- firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
- response.NetworkMap.FirewallRules = firewallRules
- response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
-
- routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
- response.NetworkMap.RoutesFirewallRules = routesFirewallRules
- response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
-
- if networkMap.ForwardingRules != nil {
- forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
- for _, rule := range networkMap.ForwardingRules {
- forwardingRules = append(forwardingRules, rule.ToProto())
- }
- response.NetworkMap.ForwardingRules = forwardingRules
- }
-
- return response
-}
-
-func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
- for _, rPeer := range peers {
- dst = append(dst, &proto.RemotePeerConfig{
- WgPubKey: rPeer.Key,
- AllowedIps: []string{rPeer.IP.String() + "/32"},
- SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
- Fqdn: rPeer.FQDN(dnsName),
- AgentVersion: rPeer.Meta.WtVersion,
- })
- }
- return dst
-}
-
// IsHealthy indicates whether the service is healthy
-func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
+func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
return &proto.Empty{}, nil
}
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
-func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
+func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error {
var err error
var turnToken *Token
@@ -803,27 +697,25 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
return status.Errorf(codes.Internal, "error handling request")
}
- peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID)
+ peerGroups, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, peer.AccountID, peer.ID)
if err != nil {
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
}
- // Get all peers in the account for forwarder port computation
- allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "")
+ plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
+
+ key, err := s.secretsManager.GetWGKey()
if err != nil {
- return fmt.Errorf("get account peers: %w", err)
+ return status.Errorf(codes.Internal, "failed getting server key")
}
- dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion)
- plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
-
- encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
+ encryptedResp, err := encryption.EncryptMessage(peerKey, key, plainResp)
if err != nil {
return status.Errorf(codes.Internal, "error handling request")
}
err = srv.Send(&proto.EncryptedMessage{
- WgPubKey: s.wgKey.PublicKey().String(),
+ WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
@@ -838,12 +730,8 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
// GetDeviceAuthorizationFlow returns a device authorization flow information
// This is used for initiating an Oauth 2 device authorization grant flow
// which will be used by our clients to Login
-func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
+func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
- start := time.Now()
- defer func() {
- log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
- }()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
@@ -852,7 +740,12 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
return nil, status.Error(codes.InvalidArgument, errMSG)
}
- err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
+ key, err := s.secretsManager.GetWGKey()
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "failed to get server key")
+ }
+
+ err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.DeviceAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.WithContext(ctx).Warn(errMSG)
@@ -882,13 +775,13 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
},
}
- encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
+ encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
}
return &proto.EncryptedMessage{
- WgPubKey: s.wgKey.PublicKey().String(),
+ WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
}, nil
}
@@ -896,12 +789,8 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
// This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login
-func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
+func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
- start := time.Now()
- defer func() {
- log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
- }()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
@@ -910,7 +799,12 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
return nil, status.Error(codes.InvalidArgument, errMSG)
}
- err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
+ key, err := s.secretsManager.GetWGKey()
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "failed to get server key")
+ }
+
+ err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.PKCEAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.WithContext(ctx).Warn(errMSG)
@@ -938,20 +832,20 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
- encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
+ encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
}
return &proto.EncryptedMessage{
- WgPubKey: s.wgKey.PublicKey().String(),
+ WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
}, nil
}
// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected,
// peer's under the same account of any updates.
-func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
+func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
@@ -976,7 +870,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
return &proto.Empty{}, nil
}
-func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
+func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
start := time.Now()
diff --git a/management/internals/shared/grpc/server_test.go b/management/internals/shared/grpc/server_test.go
new file mode 100644
index 000000000..d3a12e986
--- /dev/null
+++ b/management/internals/shared/grpc/server_test.go
@@ -0,0 +1,108 @@
+package grpc
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/encryption"
+ "github.com/netbirdio/netbird/management/internals/server/config"
+ mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
+)
+
+func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
+ testingServerKey, err := wgtypes.GeneratePrivateKey()
+ if err != nil {
+ t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
+ }
+
+ testingClientKey, err := wgtypes.GeneratePrivateKey()
+ if err != nil {
+ t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
+ }
+
+ testCases := []struct {
+ name string
+ inputFlow *config.DeviceAuthorizationFlow
+ expectedFlow *mgmtProto.DeviceAuthorizationFlow
+ expectedErrFunc require.ErrorAssertionFunc
+ expectedErrMSG string
+ expectedComparisonFunc require.ComparisonAssertionFunc
+ expectedComparisonMSG string
+ }{
+ {
+ name: "Testing No Device Flow Config",
+ inputFlow: nil,
+ expectedErrFunc: require.Error,
+ expectedErrMSG: "should return error",
+ },
+ {
+ name: "Testing Invalid Device Flow Provider Config",
+ inputFlow: &config.DeviceAuthorizationFlow{
+ Provider: "NoNe",
+ ProviderConfig: config.ProviderConfig{
+ ClientID: "test",
+ },
+ },
+ expectedErrFunc: require.Error,
+ expectedErrMSG: "should return error",
+ },
+ {
+ name: "Testing Full Device Flow Config",
+ inputFlow: &config.DeviceAuthorizationFlow{
+ Provider: "hosted",
+ ProviderConfig: config.ProviderConfig{
+ ClientID: "test",
+ },
+ },
+ expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
+ Provider: 0,
+ ProviderConfig: &mgmtProto.ProviderConfig{
+ ClientID: "test",
+ },
+ },
+ expectedErrFunc: require.NoError,
+ expectedErrMSG: "should not return error",
+ expectedComparisonFunc: require.Equal,
+ expectedComparisonMSG: "should match",
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ mgmtServer := &Server{
+ secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey},
+ config: &config.Config{
+ DeviceAuthorizationFlow: testCase.inputFlow,
+ },
+ }
+
+ message := &mgmtProto.DeviceAuthorizationFlowRequest{}
+ key, err := mgmtServer.secretsManager.GetWGKey()
+ require.NoError(t, err, "should be able to get server key")
+
+ encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message)
+ require.NoError(t, err, "should be able to encrypt message")
+
+ resp, err := mgmtServer.GetDeviceAuthorizationFlow(
+ context.TODO(),
+ &mgmtProto.EncryptedMessage{
+ WgPubKey: testingClientKey.PublicKey().String(),
+ Body: encryptedMSG,
+ },
+ )
+ testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
+ if testCase.expectedComparisonFunc != nil {
+ flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
+
+ err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
+ require.NoError(t, err, "should be able to decrypt")
+
+ testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
+ testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
+ }
+ })
+ }
+}
diff --git a/management/server/token_mgr.go b/management/internals/shared/grpc/token_mgr.go
similarity index 87%
rename from management/server/token_mgr.go
rename to management/internals/shared/grpc/token_mgr.go
index f9293e7a8..ccb32202f 100644
--- a/management/server/token_mgr.go
+++ b/management/internals/shared/grpc/token_mgr.go
@@ -1,4 +1,4 @@
-package server
+package grpc
import (
"context"
@@ -10,8 +10,10 @@ import (
"time"
log "github.com/sirupsen/logrus"
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
@@ -28,6 +30,7 @@ type SecretsManager interface {
GenerateRelayToken() (*Token, error)
SetupRefresh(ctx context.Context, accountID, peerKey string)
CancelRefresh(peerKey string)
+ GetWGKey() (wgtypes.Key, error)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
@@ -37,16 +40,22 @@ type TimeBasedAuthSecretsManager struct {
relayCfg *nbconfig.Relay
turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator
- updateManager *PeersUpdateManager
+ updateManager network_map.PeersUpdateManager
settingsManager settings.Manager
groupsManager groups.Manager
turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{}
+ wgKey wgtypes.Key
}
type Token auth.Token
-func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
+func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) (*TimeBasedAuthSecretsManager, error) {
+ key, err := wgtypes.GeneratePrivateKey()
+ if err != nil {
+ return nil, err
+ }
+
mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager,
turnCfg: turnCfg,
@@ -55,6 +64,7 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
relayCancelMap: make(map[string]chan struct{}),
settingsManager: settingsManager,
groupsManager: groupsManager,
+ wgKey: key,
}
if turnCfg != nil {
@@ -80,7 +90,12 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
}
}
- return mgr
+ return mgr, nil
+}
+
+// GetWGKey returns WireGuard private key used to generate peer keys
+func (m *TimeBasedAuthSecretsManager) GetWGKey() (wgtypes.Key, error) {
+ return m.wgKey, nil
}
// GenerateTurnToken generates new time-based secret credentials for TURN
@@ -152,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI
relayCancel := make(chan struct{}, 1)
m.relayCancelMap[peerID] = relayCancel
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
- log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID)
+ log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID)
}
}
@@ -163,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc
for {
select {
case <-cancel:
- log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID)
+ log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
@@ -178,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac
for {
select {
case <-cancel:
- log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID)
+ log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewRelayTokens(ctx, accountID, peerID)
@@ -227,7 +242,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
- m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
+ m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
@@ -251,7 +266,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
- m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
+ m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
diff --git a/management/server/token_mgr_test.go b/management/internals/shared/grpc/token_mgr_test.go
similarity index 90%
rename from management/server/token_mgr_test.go
rename to management/internals/shared/grpc/token_mgr_test.go
index 5c956dc31..98eb66fb5 100644
--- a/management/server/token_mgr_test.go
+++ b/management/internals/shared/grpc/token_mgr_test.go
@@ -1,4 +1,4 @@
-package server
+package grpc
import (
"context"
@@ -13,6 +13,8 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
@@ -31,7 +33,7 @@ var TurnTestHost = &config.Host{
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
- peersManager := NewPeersUpdateManager(nil)
+ peersManager := update_channel.NewPeersUpdateManager(nil)
rc := &config.Relay{
Addresses: []string{"localhost:0"},
@@ -44,12 +46,13 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
- tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
+ tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
+ require.NoError(t, err)
turnCredentials, err := tested.GenerateTurnToken()
require.NoError(t, err)
@@ -80,7 +83,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
ttl := util.Duration{Duration: 2 * time.Second}
secret := "some_secret"
- peersManager := NewPeersUpdateManager(nil)
+ peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer)
@@ -96,12 +99,13 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
- tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
+ tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
+ require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -116,7 +120,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
}
- var updates []*UpdateMessage
+ var updates []*network_map.UpdateMessage
loop:
for timeout := time.After(5 * time.Second); ; {
@@ -185,7 +189,7 @@ loop:
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
- peersManager := NewPeersUpdateManager(nil)
+ peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer"
rc := &config.Relay{
@@ -199,12 +203,13 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
- tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
+ tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
+ require.NoError(t, err)
tested.SetupRefresh(context.Background(), "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {
diff --git a/management/main.go b/management/main.go
index 561ed8f26..ff8482f97 100644
--- a/management/main.go
+++ b/management/main.go
@@ -1,11 +1,19 @@
package main
import (
- "github.com/netbirdio/netbird/management/cmd"
+ "log"
+ "net/http"
+ // nolint:gosec
+ _ "net/http/pprof"
"os"
+
+ "github.com/netbirdio/netbird/management/cmd"
)
func main() {
+ go func() {
+ log.Println(http.ListenAndServe("localhost:6060", nil))
+ }()
if err := cmd.Execute(); err != nil {
os.Exit(1)
}
diff --git a/management/server/account.go b/management/server/account.go
index dca105ddf..405a3c0f6 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -11,12 +11,12 @@ import (
"reflect"
"regexp"
"slices"
- "strconv"
"strings"
"sync"
- "sync/atomic"
"time"
+ "github.com/netbirdio/netbird/shared/auth"
+
cacheStore "github.com/eko/gocache/lib/v4/store"
"github.com/eko/gocache/store/redis/v4"
"github.com/rs/xid"
@@ -26,6 +26,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter/hook"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -35,7 +37,6 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
- "github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -68,28 +69,29 @@ type DefaultAccountManager struct {
cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{}
- peersUpdateManager *PeersUpdateManager
+ networkMapController network_map.Controller
idpManager idp.Manager
cacheManager *nbcache.AccountUserDataCache
externalCacheManager nbcache.UserDataCache
ctx context.Context
eventStore activity.Store
geo geolocation.Geolocation
- ephemeralManager ephemeral.Manager
requestBuffer *AccountRequestBuffer
proxyController port_forwarding.Controller
settingsManager settings.Manager
+ // config contains the management server configuration
+ config *nbconfig.Config
+
// singleAccountMode indicates whether the instance has a single account.
// If true, then every new user will end up under the same account.
// This value will be set to false if management service has more than one account.
singleAccountMode bool
// singleAccountModeDomain is a domain to use in singleAccountMode setup
singleAccountModeDomain string
- // dnsDomain is used for peer resolution. This is appended to the peer's name
- dnsDomain string
+
peerLoginExpiry Scheduler
peerInactivityExpiry Scheduler
@@ -103,14 +105,11 @@ type DefaultAccountManager struct {
permissionsManager permissions.Manager
- accountUpdateLocks sync.Map
- updateAccountPeersBufferInterval atomic.Int64
-
- loginFilter *loginFilter
-
disableDefaultPolicy bool
}
+var _ account.Manager = (*DefaultAccountManager)(nil)
+
func isUniqueConstraintError(err error) bool {
switch {
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
@@ -176,11 +175,11 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(
ctx context.Context,
+ config *nbconfig.Config,
store store.Store,
- peersUpdateManager *PeersUpdateManager,
+ networkMapController network_map.Controller,
idpManager idp.Manager,
singleAccountModeDomain string,
- dnsDomain string,
eventStore activity.Store,
geo geolocation.Geolocation,
userDeleteFromIDPEnabled bool,
@@ -198,13 +197,13 @@ func BuildManager(
am := &DefaultAccountManager{
Store: store,
+ config: config,
geo: geo,
- peersUpdateManager: peersUpdateManager,
+ networkMapController: networkMapController,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
- dnsDomain: dnsDomain,
eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(),
peerInactivityExpiry: NewDefaultScheduler(),
@@ -215,11 +214,10 @@ func BuildManager(
proxyController: proxyController,
settingsManager: settingsManager,
permissionsManager: permissionsManager,
- loginFilter: newLoginFilter(),
disableDefaultPolicy: disableDefaultPolicy,
}
- am.startWarmup(ctx)
+ am.networkMapController.StartWarmup(ctx)
accountsCounter, err := store.GetAccountsCounter(ctx)
if err != nil {
@@ -238,7 +236,7 @@ func BuildManager(
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter)
}
- cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
+ cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
if err != nil {
return nil, fmt.Errorf("getting cache store: %s", err)
}
@@ -263,36 +261,6 @@ func BuildManager(
return am, nil
}
-func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
- am.ephemeralManager = em
-}
-
-func (am *DefaultAccountManager) startWarmup(ctx context.Context) {
- var initialInterval int64
- intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
- interval, err := strconv.Atoi(intervalStr)
- if err != nil {
- initialInterval = 1
- log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
- } else {
- initialInterval = int64(interval) * 10
- go func() {
- startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
- startupPeriod, err := strconv.Atoi(startupPeriodStr)
- if err != nil {
- startupPeriod = 1
- log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
- }
- time.Sleep(time.Duration(startupPeriod) * time.Second)
- am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
- log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
- }()
- }
- am.updateAccountPeersBufferInterval.Store(initialInterval)
- log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
-
-}
-
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
return am.externalCacheManager
}
@@ -327,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
- if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
+ if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
return err
}
+ if oldSettings.Extra != nil && newSettings.Extra != nil &&
+ oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
+ approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
+ if err != nil {
+ return fmt.Errorf("failed to approve pending peers: %w", err)
+ }
+
+ if approvedCount > 0 {
+ log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID)
+ updateAccountPeers = true
+ }
+ }
+
if oldSettings.NetworkRange != newSettings.NetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
@@ -340,7 +321,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
- oldSettings.DNSDomain != newSettings.DNSDomain {
+ oldSettings.DNSDomain != newSettings.DNSDomain ||
+ oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
updateAccountPeers = true
}
@@ -351,6 +333,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
}
+ newSettings.Extra.IntegratedValidatorGroups = oldSettings.Extra.IntegratedValidatorGroups
+ newSettings.Extra.IntegratedValidator = oldSettings.Extra.IntegratedValidator
+
if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil {
return err
}
@@ -376,6 +361,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
+ am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
}
@@ -401,7 +387,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return newSettings, nil
}
-func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
+func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -415,17 +401,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
- peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
- if err != nil {
- return err
- }
-
- peersMap := make(map[string]*nbpeer.Peer, len(peers))
- for _, peer := range peers {
- peersMap[peer.ID] = peer
- }
-
- return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
+ return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
}
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
@@ -477,6 +453,14 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con
}
}
+func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
+ if oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
+ am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateVersionUpdated, map[string]any{
+ "version": newSettings.AutoUpdateVersion,
+ })
+ }
+}
+
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
@@ -816,6 +800,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID)
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ // nolint:staticcheck
+ ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
+
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil {
return nil, nil, err
@@ -1040,7 +1031,7 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
}
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
-func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth,
+func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth auth.UserAuth,
primaryDomain bool,
) error {
if userAuth.Domain == "" {
@@ -1089,7 +1080,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
ctx context.Context,
userAccountID string,
domainAccountID string,
- userAuth nbcontext.UserAuth,
+ userAuth auth.UserAuth,
) error {
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain)
@@ -1108,7 +1099,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
-func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
+func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) {
if userAuth.UserId == "" {
return "", fmt.Errorf("user ID is empty")
}
@@ -1139,7 +1130,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
return newAccount.Id, nil
}
-func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
+func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) {
newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID
@@ -1251,7 +1242,7 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
- log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err)
+ log.Errorf("failed to get account onboarding for account %s: %v", accountID, err)
return nil, err
}
@@ -1303,7 +1294,7 @@ func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, ac
return newOnboarding, nil
}
-func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
+func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
if userAuth.UserId == "" {
return "", "", errors.New(emptyUserID)
}
@@ -1347,7 +1338,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
-func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
+func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
if userAuth.IsChild || userAuth.IsPAT {
return nil
}
@@ -1465,21 +1456,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
}
- if settings.GroupsPropagationEnabled {
- removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
- if err != nil {
- return err
- }
+ removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
+ if err != nil {
+ return err
+ }
- newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
- if err != nil {
- return err
- }
+ newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
+ if err != nil {
+ return err
+ }
- if removedGroupAffectsPeers || newGroupsAffectsPeers {
- log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
- am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
- }
+ if removedGroupAffectsPeers || newGroupsAffectsPeers {
+ log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
+ am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
}
return nil
@@ -1505,7 +1494,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
//
// UserAuth IsChild -> checks that account exists
-func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
+func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth auth.UserAuth) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory)
@@ -1584,7 +1573,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
return domainAccountID, cancel, nil
}
-func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
+func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth auth.UserAuth) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
@@ -1632,23 +1621,14 @@ func handleNotFound(err error) error {
return nil
}
-func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool {
+func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAuth) bool {
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
-func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool {
- return am.loginFilter.allowLogin(wgPubKey, metahash)
-}
-
-func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
- start := time.Now()
- defer func() {
- log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start))
- }()
-
- peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
+func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
+ peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
- return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
+ return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
@@ -1656,10 +1636,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
- metahash := metaHash(meta, realIP.String())
- am.loginFilter.addLogin(peerPubKey, metahash)
-
- return peer, netMap, postureChecks, nil
+ return peer, netMap, postureChecks, dnsfwdPort, nil
}
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
@@ -1676,41 +1653,19 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return err
}
- _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
+ _, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
if err != nil {
- return mapError(ctx, err)
+ return err
}
return nil
}
-// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
-func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
- return am.peersUpdateManager.GetAllConnectedPeers(), nil
-}
-
-// HasConnectedChannel returns true if peers has channel in update manager, otherwise false
-func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool {
- return am.peersUpdateManager.HasChannel(peerID)
-}
-
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
func isDomainValid(domain string) bool {
return invalidDomainRegexp.MatchString(domain)
}
-// GetDNSDomain returns the configured dnsDomain
-func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string {
- if settings == nil {
- return am.dnsDomain
- }
- if settings.DNSDomain == "" {
- return am.dnsDomain
- }
-
- return settings.DNSDomain
-}
-
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) {
peers := []*nbpeer.Peer{}
log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID)
@@ -2129,7 +2084,14 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
}
if updateNetworkMap {
- am.BufferUpdateAccountPeers(ctx, accountID)
+ peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
+ if err != nil {
+ return err
+ }
+ err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
+ if err != nil {
+ return fmt.Errorf("notify network map controller of peer update: %w", err)
+ }
}
return nil
}
@@ -2177,7 +2139,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
if err != nil {
return fmt.Errorf("get account settings: %w", err)
}
- dnsDomain := am.GetDNSDomain(settings)
+ dnsDomain := am.networkMapController.GetDNSDomain(settings)
eventMeta := peer.EventMeta(dnsDomain)
oldIP := peer.IP.String()
diff --git a/management/server/account/manager.go b/management/server/account/manager.go
index a1ed9498b..b5921ec7a 100644
--- a/management/server/account/manager.go
+++ b/management/server/account/manager.go
@@ -6,13 +6,13 @@ import (
"net/netip"
"time"
+ "github.com/netbirdio/netbird/shared/auth"
+
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
- "github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -45,10 +45,10 @@ type Manager interface {
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
AccountExists(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
- GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
+ GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
GetUserByID(ctx context.Context, id string) (*types.User, error)
- GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
+ GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
@@ -89,7 +89,6 @@ type Manager interface {
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
- GetDNSDomain(settings *types.Settings) string
StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
@@ -97,10 +96,8 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
- LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
- SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
- GetAllConnectedPeers() (map[string]struct{}, error)
- HasConnectedChannel(peerID string) bool
+ LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
+ SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
@@ -109,8 +106,8 @@ type Manager interface {
GetIdpManager() idp.Manager
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
- GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
- SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
+ GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
+ SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
@@ -120,12 +117,10 @@ type Manager interface {
UpdateAccountPeers(ctx context.Context, accountID string)
BufferUpdateAccountPeers(ctx context.Context, accountID string)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
- SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
+ SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
GetStore() store.Store
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
- GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
- SetEphemeralManager(em ephemeral.Manager)
- AllowSync(string, uint64) bool
+ GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
}
diff --git a/management/server/account/request_buffer.go b/management/server/account/request_buffer.go
new file mode 100644
index 000000000..eced1929f
--- /dev/null
+++ b/management/server/account/request_buffer.go
@@ -0,0 +1,11 @@
+package account
+
+import (
+ "context"
+
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+type RequestBuffer interface {
+ GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error)
+}
diff --git a/management/server/account_test.go b/management/server/account_test.go
index 07d2f2383..25818ada2 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -22,10 +22,15 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
+ "github.com/netbirdio/netbird/management/internals/server/config"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -42,6 +47,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
+ "github.com/netbirdio/netbird/shared/auth"
)
func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account *types.Account, userID string) {
@@ -391,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
- networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
+ networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@@ -406,7 +412,7 @@ func TestNewAccount(t *testing.T) {
}
func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -442,7 +448,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
}
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
- type initUserParams nbcontext.UserAuth
+ type initUserParams auth.UserAuth
var (
publicDomain = "public.com"
@@ -465,7 +471,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
testCases := []struct {
name string
- inputClaims nbcontext.UserAuth
+ inputClaims auth.UserAuth
inputInitUserParams initUserParams
inputUpdateAttrs bool
inputUpdateClaimAccount bool
@@ -480,7 +486,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
}{
{
name: "New User With Public Domain",
- inputClaims: nbcontext.UserAuth{
+ inputClaims: auth.UserAuth{
Domain: publicDomain,
UserId: "pub-domain-user",
DomainCategory: types.PublicCategory,
@@ -497,7 +503,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New User With Unknown Domain",
- inputClaims: nbcontext.UserAuth{
+ inputClaims: auth.UserAuth{
Domain: unknownDomain,
UserId: "unknown-domain-user",
DomainCategory: types.UnknownCategory,
@@ -514,7 +520,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New User With Private Domain",
- inputClaims: nbcontext.UserAuth{
+ inputClaims: auth.UserAuth{
Domain: privateDomain,
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
@@ -531,7 +537,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New Regular User With Existing Private Domain",
- inputClaims: nbcontext.UserAuth{
+ inputClaims: auth.UserAuth{
Domain: privateDomain,
UserId: "new-pvt-domain-user",
DomainCategory: types.PrivateCategory,
@@ -549,7 +555,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "Existing User With Existing Reclassified Private Domain",
- inputClaims: nbcontext.UserAuth{
+ inputClaims: auth.UserAuth{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: types.PrivateCategory,
@@ -566,7 +572,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "Existing Account Id With Existing Reclassified Private Domain",
- inputClaims: nbcontext.UserAuth{
+ inputClaims: auth.UserAuth{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: types.PrivateCategory,
@@ -584,7 +590,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "User With Private Category And Empty Domain",
- inputClaims: nbcontext.UserAuth{
+ inputClaims: auth.UserAuth{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
@@ -603,7 +609,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
@@ -613,7 +619,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs {
- err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
+ err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
}
@@ -644,7 +650,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
userId := "user-id"
domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain, false)
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed")
@@ -653,7 +659,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
Domain: domain,
UserId: userId,
@@ -705,7 +711,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
}
func TestAccountManager_PrivateAccount(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -731,7 +737,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
}
func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -768,7 +774,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
}
func TestAccountManager_GetAccountByUserID(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -805,7 +811,7 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string)
}
func TestAccountManager_GetAccount(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -843,7 +849,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
}
func TestAccountManager_DeleteAccount(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -912,19 +918,19 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
}
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
Domain: "example.com",
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
}
- publicClaims := nbcontext.UserAuth{
+ publicClaims := auth.UserAuth{
Domain: "test.com",
UserId: "public-domain-user",
DomainCategory: types.PublicCategory,
}
- am, err := createManager(b)
+ am, _, err := createManager(b)
if err != nil {
b.Fatal(err)
return
@@ -1016,7 +1022,7 @@ func genUsers(p string, n int) map[string]*types.User {
}
func TestAccountManager_AddPeer(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1086,7 +1092,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
}
func TestAccountManager_AddPeerWithUserID(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1154,8 +1160,17 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
}
+func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
+ t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
+ testAccountManager_NetworkUpdates_SaveGroup(t)
+}
+
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+ testAccountManager_NetworkUpdates_SaveGroup(t)
+}
+
+func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
+ manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
@@ -1181,8 +1196,8 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}, true)
require.NoError(t, err)
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
- defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
+ defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
@@ -1205,11 +1220,20 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
wg.Wait()
}
-func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
- manager, account, peer1, _, _ := setupNetworkMapTest(t)
+func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
+ t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
+ testAccountManager_NetworkUpdates_DeletePolicy(t)
+}
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
- defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
+ testAccountManager_NetworkUpdates_DeletePolicy(t)
+}
+
+func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
+ manager, updateManager, account, peer1, _, _ := setupNetworkMapTest(t)
+
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
+ defer updateManager.CloseChannel(context.Background(), peer1.ID)
// Ensure that we do not receive an update message before the policy is deleted
time.Sleep(time.Second)
@@ -1239,8 +1263,17 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
wg.Wait()
}
+func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
+ t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
+ testAccountManager_NetworkUpdates_SavePolicy(t)
+}
+
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
- manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
+ testAccountManager_NetworkUpdates_SavePolicy(t)
+}
+
+func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
+ manager, updateManager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{
AccountID: account.Id,
@@ -1253,8 +1286,8 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
return
}
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
- defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
+ defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
@@ -1288,8 +1321,17 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
wg.Wait()
}
+func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
+ t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
+ testAccountManager_NetworkUpdates_DeletePeer(t)
+}
+
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
- manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
+ testAccountManager_NetworkUpdates_DeletePeer(t)
+}
+
+func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
+ manager, updateManager, account, peer1, _, peer3 := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
@@ -1318,8 +1360,11 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return
}
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
- defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ // We need to sleep to wait for the buffer peer update
+ time.Sleep(300 * time.Millisecond)
+
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
+ defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
@@ -1341,11 +1386,20 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
wg.Wait()
}
-func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
+ t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
+ testAccountManager_NetworkUpdates_DeleteGroup(t)
+}
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
- defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
+ testAccountManager_NetworkUpdates_DeleteGroup(t)
+}
+
+func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
+ manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
+ defer updateManager.CloseChannel(context.Background(), peer1.ID)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
@@ -1377,6 +1431,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return
}
+ for drained := false; !drained; {
+ select {
+ case <-updMsg:
+ default:
+ drained = true
+ }
+ }
+
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -1404,7 +1466,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
}
func TestAccountManager_DeletePeer(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1485,7 +1547,7 @@ func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventTy
}
func TestGetUsersFromAccount(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -1736,7 +1798,9 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24",
},
},
+ NetworkMapCache: &types.NetworkMapBuilder{},
}
+ account.InitOnce()
err := hasNilField(account)
if err != nil {
t.Fatal(err)
@@ -1782,7 +1846,7 @@ func hasNilField(x interface{}) error {
}
func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1797,7 +1861,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
}
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1853,7 +1917,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
}
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1896,7 +1960,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1958,7 +2022,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}
func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1994,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
}
+func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) {
+ manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+
+ accountID := account.Id
+ userID := account.Users[account.CreatedBy].Id
+ ctx := context.Background()
+
+ newSettings := account.Settings.Copy()
+ newSettings.Extra = &types.ExtraSettings{
+ PeerApprovalEnabled: true,
+ }
+ _, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
+ require.NoError(t, err)
+
+ peer1.Status.RequiresApproval = true
+ peer2.Status.RequiresApproval = true
+ peer3.Status.RequiresApproval = false
+
+ require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1))
+ require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2))
+ require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3))
+
+ newSettings = account.Settings.Copy()
+ newSettings.Extra = &types.ExtraSettings{
+ PeerApprovalEnabled: false,
+ }
+ _, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
+ require.NoError(t, err)
+
+ accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
+ require.NoError(t, err)
+
+ for _, peer := range accountPeers {
+ assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID)
+ }
+}
+
func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct {
name string
@@ -2622,7 +2723,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
func TestAccount_SetJWTGroups(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
// create a new account
@@ -2648,7 +2749,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("skip sync for token auth type", func(t *testing.T) {
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group3"},
@@ -2663,7 +2764,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("empty jwt groups", func(t *testing.T) {
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{},
@@ -2677,7 +2778,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("jwt match existing api group", func(t *testing.T) {
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group1"},
@@ -2698,7 +2799,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group1"},
@@ -2716,7 +2817,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("add jwt group", func(t *testing.T) {
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group1", "group2"},
@@ -2730,7 +2831,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("existed group not update", func(t *testing.T) {
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group2"},
@@ -2744,7 +2845,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("add new group", func(t *testing.T) {
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
UserId: "user2",
AccountId: "accountID",
Groups: []string{"group1", "group3"},
@@ -2762,7 +2863,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{},
@@ -2777,7 +2878,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
UserId: "user2",
AccountId: "accountID",
Groups: []string{},
@@ -2864,18 +2965,18 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
// Fatalf(format string, args ...interface{})
// }
-func createManager(t testing.TB) (*DefaultAccountManager, error) {
+func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
t.Helper()
store, err := createStore(t)
if err != nil {
- return nil, err
+ return nil, nil, err
}
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
- return nil, err
+ return nil, nil, err
}
ctrl := gomock.NewController(t)
@@ -2893,12 +2994,17 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store)
- manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
+ ctx := context.Background()
+
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := NewAccountRequestBuffer(ctx, store)
+ networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
+ manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
- return nil, err
+ return nil, nil, err
}
- return manager, nil
+ return manager, updateManager, nil
}
func createStore(t testing.TB) (store.Store, error) {
@@ -2927,10 +3033,10 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
}
}
-func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
+func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
t.Helper()
- manager, err := createManager(t)
+ manager, updateManager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2971,10 +3077,10 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account,
peer2 := getPeer(manager, setupKey)
peer3 := getPeer(manager, setupKey)
- return manager, account, peer1, peer2, peer3
+ return manager, updateManager, account, peer1, peer2, peer3
}
-func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
+func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
@@ -2984,7 +3090,7 @@ func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessag
}
}
-func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
+func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
@@ -3022,7 +3128,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
- manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
+ manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -3031,16 +3137,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
- peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
- peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
+ updateManager.CreateChannel(ctx, peerID)
}
- manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
- _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
+ _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
assert.NoError(b, err)
}
@@ -3085,7 +3189,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
- manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
+ manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -3094,11 +3198,10 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
- peerChannels := make(map[string]chan *UpdateMessage)
+
for peerID := range account.Peers {
- peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
+ updateManager.CreateChannel(ctx, peerID)
}
- manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
@@ -3155,7 +3258,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
- manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
+ manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -3164,11 +3267,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
- peerChannels := make(map[string]chan *UpdateMessage)
+
for peerID := range account.Peers {
- peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
+ updateManager.CreateChannel(ctx, peerID)
}
- manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
@@ -3227,7 +3329,7 @@ func TestMain(m *testing.M) {
}
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -3273,7 +3375,7 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
}
func Test_UpdateToPrimaryAccount(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -3303,12 +3405,12 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
}
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err)
t.Run("memory cache", func(t *testing.T) {
t.Run("should always return true", func(t *testing.T) {
- cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
+ cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
cold, err := manager.isCacheCold(context.Background(), cacheStore)
@@ -3323,7 +3425,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
t.Cleanup(cleanup)
t.Setenv(cache.RedisStoreEnvVar, redisURL)
- cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
+ cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
t.Run("should return true when no account exists", func(t *testing.T) {
@@ -3353,7 +3455,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
}
func TestPropagateUserGroupMemberships(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err)
ctx := context.Background()
@@ -3470,7 +3572,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
@@ -3502,7 +3604,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
}
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
@@ -3541,7 +3643,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
}
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -3608,7 +3710,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
}
func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -3630,7 +3732,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
// Test adding new user to existing account with approval required
newUserID := "new-user-id"
- userAuth := nbcontext.UserAuth{
+ userAuth := auth.UserAuth{
UserId: newUserID,
Domain: "example.com",
DomainCategory: types.PrivateCategory,
@@ -3654,13 +3756,13 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
}
func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create a domain-based account without user approval
- ownerUserAuth := nbcontext.UserAuth{
+ ownerUserAuth := auth.UserAuth{
UserId: "owner-user",
Domain: "example.com",
DomainCategory: types.PrivateCategory,
@@ -3679,7 +3781,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
// Test adding new user to existing account without approval required
newUserID := "new-user-id"
- userAuth := nbcontext.UserAuth{
+ userAuth := auth.UserAuth{
UserId: newUserID,
Domain: "example.com",
DomainCategory: types.PrivateCategory,
diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go
index 5c5989f84..6344b2904 100644
--- a/management/server/activity/codes.go
+++ b/management/server/activity/codes.go
@@ -179,6 +179,9 @@ const (
PeerIPUpdated Activity = 88
UserApproved Activity = 89
UserRejected Activity = 90
+ UserCreated Activity = 91
+
+ AccountAutoUpdateVersionUpdated Activity = 92
AccountDeleted Activity = 99999
)
@@ -286,8 +289,12 @@ var activityMap = map[Activity]Code{
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
- UserApproved: {"User approved", "user.approve"},
- UserRejected: {"User rejected", "user.reject"},
+
+ UserApproved: {"User approved", "user.approve"},
+ UserRejected: {"User rejected", "user.reject"},
+ UserCreated: {"User created", "user.create"},
+
+ AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"},
}
// StringCode returns a string code of the activity
diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go
index 80b165938..ffecb6b8f 100644
--- a/management/server/activity/store/sql_store.go
+++ b/management/server/activity/store/sql_store.go
@@ -7,6 +7,7 @@ import (
"path/filepath"
"runtime"
"strconv"
+ "time"
log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
@@ -273,15 +274,21 @@ func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, e
return nil, err
}
- if storeEngine == types.SqliteStoreEngine {
- sqlDB.SetMaxOpenConns(1)
- } else {
- conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
- if err != nil {
- conns = runtime.NumCPU()
- }
- sqlDB.SetMaxOpenConns(conns)
+ conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
+ if err != nil {
+ conns = runtime.NumCPU()
}
+ if storeEngine == types.SqliteStoreEngine {
+ conns = 1
+ }
+
+ sqlDB.SetMaxOpenConns(conns)
+ sqlDB.SetMaxIdleConns(conns)
+ sqlDB.SetConnMaxLifetime(time.Hour)
+ sqlDB.SetConnMaxIdleTime(3 * time.Minute)
+
+ log.Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v",
+ conns, conns, time.Hour, 3*time.Minute)
return db, nil
}
diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go
index ece9dc321..0c62357dc 100644
--- a/management/server/auth/manager.go
+++ b/management/server/auth/manager.go
@@ -9,18 +9,19 @@ import (
"github.com/golang-jwt/jwt/v5"
+ "github.com/netbirdio/netbird/shared/auth"
+
"github.com/netbirdio/netbird/base62"
- nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
+ nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
var _ Manager = (*manager)(nil)
type Manager interface {
- ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
- EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
+ ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error)
+ EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error)
MarkPATUsed(ctx context.Context, tokenID string) error
GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
}
@@ -55,20 +56,20 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s
}
}
-func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
+func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) {
token, err := m.validator.ValidateAndParse(ctx, value)
if err != nil {
- return nbcontext.UserAuth{}, nil, err
+ return auth.UserAuth{}, nil, err
}
userAuth, err := m.extractor.ToUserAuth(token)
if err != nil {
- return nbcontext.UserAuth{}, nil, err
+ return auth.UserAuth{}, nil, err
}
return userAuth, token, err
}
-func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
+func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
if userAuth.IsChild || userAuth.IsPAT {
return userAuth, nil
}
diff --git a/management/server/auth/manager_mock.go b/management/server/auth/manager_mock.go
index 30a7a7161..edf158a49 100644
--- a/management/server/auth/manager_mock.go
+++ b/management/server/auth/manager_mock.go
@@ -3,9 +3,10 @@ package auth
import (
"context"
+ "github.com/netbirdio/netbird/shared/auth"
+
"github.com/golang-jwt/jwt/v5"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
)
@@ -15,18 +16,18 @@ var (
// @note really dislike this mocking approach but rather than have to do additional test refactoring.
type MockManager struct {
- ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
- EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
+ ValidateAndParseTokenFunc func(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error)
+ EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error)
MarkPATUsedFunc func(ctx context.Context, tokenID string) error
GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
}
// EnsureUserAccessByJWTGroups implements Manager.
-func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
+func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
if m.EnsureUserAccessByJWTGroupsFunc != nil {
return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token)
}
- return nbcontext.UserAuth{}, nil
+ return auth.UserAuth{}, nil
}
// GetPATInfo implements Manager.
@@ -46,9 +47,9 @@ func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error {
}
// ValidateAndParseToken implements Manager.
-func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
+func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) {
if m.ValidateAndParseTokenFunc != nil {
return m.ValidateAndParseTokenFunc(ctx, value)
}
- return nbcontext.UserAuth{}, &jwt.Token{}, nil
+ return auth.UserAuth{}, &jwt.Token{}, nil
}
diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go
index c8015eb37..b9f091b1e 100644
--- a/management/server/auth/manager_test.go
+++ b/management/server/auth/manager_test.go
@@ -17,10 +17,10 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/auth"
- nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
+ nbauth "github.com/netbirdio/netbird/shared/auth"
+ nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
@@ -131,7 +131,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
}
// this has been validated and parsed by ValidateAndParseToken
- userAuth := nbcontext.UserAuth{
+ userAuth := nbauth.UserAuth{
AccountId: account.Id,
Domain: domain,
UserId: userId,
@@ -236,7 +236,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
tests := []struct {
name string
tokenFunc func() string
- expected *nbcontext.UserAuth // nil indicates expected error
+ expected *nbauth.UserAuth // nil indicates expected error
}{
{
name: "Valid with custom claims",
@@ -258,7 +258,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
tokenString, _ := token.SignedString(key)
return tokenString
},
- expected: &nbcontext.UserAuth{
+ expected: &nbauth.UserAuth{
UserId: "user-id|123",
AccountId: "account-id|567",
Domain: "http://localhost",
@@ -282,7 +282,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
tokenString, _ := token.SignedString(key)
return tokenString
},
- expected: &nbcontext.UserAuth{
+ expected: &nbauth.UserAuth{
UserId: "user-id|123",
},
},
diff --git a/management/server/cache/idp.go b/management/server/cache/idp.go
index 1b31ff82a..19dfc0f38 100644
--- a/management/server/cache/idp.go
+++ b/management/server/cache/idp.go
@@ -18,6 +18,7 @@ const (
DefaultIDPCacheExpirationMax = 7 * 24 * time.Hour // 7 days
DefaultIDPCacheExpirationMin = 3 * 24 * time.Hour // 3 days
DefaultIDPCacheCleanupInterval = 30 * time.Minute
+ DefaultIDPCacheOpenConn = 100
)
// UserDataCache is an interface that wraps the basic Get, Set and Delete methods for idp.UserData objects.
diff --git a/management/server/cache/idp_test.go b/management/server/cache/idp_test.go
index 3fcfbb11a..0e8061e94 100644
--- a/management/server/cache/idp_test.go
+++ b/management/server/cache/idp_test.go
@@ -33,7 +33,7 @@ func TestNewIDPCacheManagers(t *testing.T) {
t.Cleanup(cleanup)
t.Setenv(cache.RedisStoreEnvVar, redisURL)
}
- cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval)
+ cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval, cache.DefaultIDPCacheOpenConn)
if err != nil {
t.Fatalf("couldn't create cache store: %s", err)
}
diff --git a/management/server/cache/store.go b/management/server/cache/store.go
index 1c141a180..54b0242de 100644
--- a/management/server/cache/store.go
+++ b/management/server/cache/store.go
@@ -3,6 +3,7 @@ package cache
import (
"context"
"fmt"
+ "math"
"os"
"time"
@@ -20,24 +21,27 @@ const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS"
// NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar
// to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store.
-func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration) (store.StoreInterface, error) {
+func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) {
redisAddr := os.Getenv(RedisStoreEnvVar)
if redisAddr != "" {
- return getRedisStore(ctx, redisAddr)
+ return getRedisStore(ctx, redisAddr, maxConn)
}
goc := gocache.New(maxTimeout, cleanupInterval)
return gocache_store.NewGoCache(goc), nil
}
-func getRedisStore(ctx context.Context, redisEnvAddr string) (store.StoreInterface, error) {
+func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) {
options, err := redis.ParseURL(redisEnvAddr)
if err != nil {
return nil, fmt.Errorf("parsing redis cache url: %s", err)
}
- options.MaxIdleConns = 6
- options.MinIdleConns = 3
- options.MaxActiveConns = 100
+ options.MaxIdleConns = int(math.Ceil(float64(maxConn) * 0.5)) // 50% of max conns
+ options.MinIdleConns = int(math.Ceil(float64(maxConn) * 0.1)) // 10% of max conns
+ options.MaxActiveConns = maxConn
+ options.ConnMaxIdleTime = 30 * time.Minute
+ options.ConnMaxLifetime = 0
+ options.PoolTimeout = 10 * time.Second
redisClient := redis.NewClient(options)
subCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
diff --git a/management/server/cache/store_test.go b/management/server/cache/store_test.go
index f49dd6bbd..1b64fd70d 100644
--- a/management/server/cache/store_test.go
+++ b/management/server/cache/store_test.go
@@ -15,7 +15,7 @@ import (
)
func TestMemoryStore(t *testing.T) {
- memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
+ memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("couldn't create memory store: %s", err)
}
@@ -42,7 +42,7 @@ func TestMemoryStore(t *testing.T) {
func TestRedisStoreConnectionFailure(t *testing.T) {
t.Setenv(cache.RedisStoreEnvVar, "redis://127.0.0.1:6379")
- _, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond)
+ _, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond, 100)
if err == nil {
t.Fatal("getting redis cache store should return error")
}
@@ -65,7 +65,7 @@ func TestRedisStoreConnectionSuccess(t *testing.T) {
}
t.Setenv(cache.RedisStoreEnvVar, redisURL)
- redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
+ redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("couldn't create redis store: %s", err)
}
diff --git a/management/server/context/auth.go b/management/server/context/auth.go
index 5cb28ddb7..cc59b8a63 100644
--- a/management/server/context/auth.go
+++ b/management/server/context/auth.go
@@ -4,7 +4,8 @@ import (
"context"
"fmt"
"net/http"
- "time"
+
+ "github.com/netbirdio/netbird/shared/auth"
)
type key int
@@ -13,45 +14,22 @@ const (
UserAuthContextKey key = iota
)
-type UserAuth struct {
- // The account id the user is accessing
- AccountId string
- // The account domain
- Domain string
- // The account domain category, TBC values
- DomainCategory string
- // Indicates whether this user was invited, TBC logic
- Invited bool
- // Indicates whether this is a child account
- IsChild bool
-
- // The user id
- UserId string
- // Last login time for this user
- LastLogin time.Time
- // The Groups the user belongs to on this account
- Groups []string
-
- // Indicates whether this user has authenticated with a Personal Access Token
- IsPAT bool
-}
-
-func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) {
+func GetUserAuthFromRequest(r *http.Request) (auth.UserAuth, error) {
return GetUserAuthFromContext(r.Context())
}
-func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request {
+func SetUserAuthInRequest(r *http.Request, userAuth auth.UserAuth) *http.Request {
return r.WithContext(SetUserAuthInContext(r.Context(), userAuth))
}
-func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) {
- if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok {
+func GetUserAuthFromContext(ctx context.Context) (auth.UserAuth, error) {
+ if userAuth, ok := ctx.Value(UserAuthContextKey).(auth.UserAuth); ok {
return userAuth, nil
}
- return UserAuth{}, fmt.Errorf("user auth not in context")
+ return auth.UserAuth{}, fmt.Errorf("user auth not in context")
}
-func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context {
+func SetUserAuthInContext(ctx context.Context, userAuth auth.UserAuth) context.Context {
//nolint
ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId)
//nolint
diff --git a/management/server/dns.go b/management/server/dns.go
index 534f43ec6..baf6debc3 100644
--- a/management/server/dns.go
+++ b/management/server/dns.go
@@ -3,54 +3,23 @@ package server
import (
"context"
"slices"
- "sync"
log "github.com/sirupsen/logrus"
- "golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
- nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
- "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
- dnsForwarderPort = 22054
- oldForwarderPort = 5353
+ dnsForwarderPort = nbdns.ForwarderServerPort
)
-const dnsForwarderPortMinVersion = "v0.59.0"
-
-// DNSConfigCache is a thread-safe cache for DNS configuration components
-type DNSConfigCache struct {
- NameServerGroups sync.Map
-}
-
-// GetNameServerGroup retrieves a cached name server group
-func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
- if c == nil {
- return nil, false
- }
- if value, ok := c.NameServerGroups.Load(key); ok {
- return value.(*proto.NameServerGroup), true
- }
- return nil, false
-}
-
-// SetNameServerGroup stores a name server group in the cache
-func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
- if c == nil {
- return
- }
- c.NameServerGroups.Store(key, value)
-}
-
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
@@ -191,99 +160,3 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return validateGroups(settings.DisabledManagementGroups, groups)
}
-
-// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
-// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
-func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
- if len(peers) == 0 {
- return oldForwarderPort
- }
-
- reqVer := semver.Canonical(requiredVersion)
-
- // Check if all peers have the required version or newer
- for _, peer := range peers {
-
- // Development version is always supported
- if peer.Meta.WtVersion == "development" {
- continue
- }
- peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
- if peerVersion == "" {
- // If any peer doesn't have version info, return 0
- return oldForwarderPort
- }
-
- // Compare versions
- if semver.Compare(peerVersion, reqVer) < 0 {
- return oldForwarderPort
- }
- }
-
- // All peers have the required version or newer
- return dnsForwarderPort
-}
-
-// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
-func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig {
- protoUpdate := &proto.DNSConfig{
- ServiceEnable: update.ServiceEnable,
- CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
- NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
- ForwarderPort: forwardPort,
- }
-
- for _, zone := range update.CustomZones {
- protoZone := convertToProtoCustomZone(zone)
- protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
- }
-
- for _, nsGroup := range update.NameServerGroups {
- cacheKey := nsGroup.ID
- if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
- protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
- } else {
- protoGroup := convertToProtoNameServerGroup(nsGroup)
- cache.SetNameServerGroup(cacheKey, protoGroup)
- protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
- }
- }
-
- return protoUpdate
-}
-
-// Helper function to convert nbdns.CustomZone to proto.CustomZone
-func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
- protoZone := &proto.CustomZone{
- Domain: zone.Domain,
- Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
- }
- for _, record := range zone.Records {
- protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
- Name: record.Name,
- Type: int64(record.Type),
- Class: record.Class,
- TTL: int64(record.TTL),
- RData: record.RData,
- })
- }
- return protoZone
-}
-
-// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
-func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
- protoGroup := &proto.NameServerGroup{
- Primary: nsGroup.Primary,
- Domains: nsGroup.Domains,
- SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
- NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
- }
- for _, ns := range nsGroup.NameServers {
- protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
- IP: ns.IP.String(),
- Port: int64(ns.Port),
- NSType: int64(ns.NSType),
- })
- }
- return protoGroup
-}
diff --git a/management/server/dns_test.go b/management/server/dns_test.go
index 83caf74ef..b5e3f2b99 100644
--- a/management/server/dns_test.go
+++ b/management/server/dns_test.go
@@ -2,9 +2,7 @@ package server
import (
"context"
- "fmt"
"net/netip"
- "reflect"
"testing"
"time"
@@ -12,6 +10,11 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
+ "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -218,7 +221,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store)
- return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
+
+ ctx := context.Background()
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := NewAccountRequestBuffer(ctx, store)
+ networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
+
+ return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createDNSStore(t *testing.T) (store.Store, error) {
@@ -344,247 +353,8 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
return am.Store.GetAccount(context.Background(), account.Id)
}
-func generateTestData(size int) nbdns.Config {
- config := nbdns.Config{
- ServiceEnable: true,
- CustomZones: make([]nbdns.CustomZone, size),
- NameServerGroups: make([]*nbdns.NameServerGroup, size),
- }
-
- for i := 0; i < size; i++ {
- config.CustomZones[i] = nbdns.CustomZone{
- Domain: fmt.Sprintf("domain%d.com", i),
- Records: []nbdns.SimpleRecord{
- {
- Name: fmt.Sprintf("record%d", i),
- Type: 1,
- Class: "IN",
- TTL: 3600,
- RData: "192.168.1.1",
- },
- },
- }
-
- config.NameServerGroups[i] = &nbdns.NameServerGroup{
- ID: fmt.Sprintf("group%d", i),
- Primary: i == 0,
- Domains: []string{fmt.Sprintf("domain%d.com", i)},
- SearchDomainsEnabled: true,
- NameServers: []nbdns.NameServer{
- {
- IP: netip.MustParseAddr("8.8.8.8"),
- Port: 53,
- NSType: 1,
- },
- },
- }
- }
-
- return config
-}
-
-func BenchmarkToProtocolDNSConfig(b *testing.B) {
- sizes := []int{10, 100, 1000}
-
- for _, size := range sizes {
- testData := generateTestData(size)
-
- b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
- cache := &DNSConfigCache{}
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- toProtocolDNSConfig(testData, cache, dnsForwarderPort)
- }
- })
-
- b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- cache := &DNSConfigCache{}
- toProtocolDNSConfig(testData, cache, dnsForwarderPort)
- }
- })
- }
-}
-
-func TestToProtocolDNSConfigWithCache(t *testing.T) {
- var cache DNSConfigCache
-
- // Create two different configs
- config1 := nbdns.Config{
- ServiceEnable: true,
- CustomZones: []nbdns.CustomZone{
- {
- Domain: "example.com",
- Records: []nbdns.SimpleRecord{
- {Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
- },
- },
- },
- NameServerGroups: []*nbdns.NameServerGroup{
- {
- ID: "group1",
- Name: "Group 1",
- NameServers: []nbdns.NameServer{
- {IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
- },
- },
- },
- }
-
- config2 := nbdns.Config{
- ServiceEnable: true,
- CustomZones: []nbdns.CustomZone{
- {
- Domain: "example.org",
- Records: []nbdns.SimpleRecord{
- {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
- },
- },
- },
- NameServerGroups: []*nbdns.NameServerGroup{
- {
- ID: "group2",
- Name: "Group 2",
- NameServers: []nbdns.NameServer{
- {IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
- },
- },
- },
- }
-
- // First run with config1
- result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort)
-
- // Second run with config2
- result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort)
-
- // Third run with config1 again
- result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort)
-
- // Verify that result1 and result3 are identical
- if !reflect.DeepEqual(result1, result3) {
- t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
- }
-
- // Verify that result2 is different from result1 and result3
- if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
- t.Errorf("Results should be different for different inputs")
- }
-
- if _, exists := cache.GetNameServerGroup("group1"); !exists {
- t.Errorf("Cache should contain name server group 'group1'")
- }
-
- if _, exists := cache.GetNameServerGroup("group2"); !exists {
- t.Errorf("Cache should contain name server group 'group2'")
- }
-}
-
-func TestComputeForwarderPort(t *testing.T) {
- // Test with empty peers list
- peers := []*nbpeer.Peer{}
- result := computeForwarderPort(peers, "v0.59.0")
- if result != oldForwarderPort {
- t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
- }
-
- // Test with peers that have old versions
- peers = []*nbpeer.Peer{
- {
- Meta: nbpeer.PeerSystemMeta{
- WtVersion: "0.57.0",
- },
- },
- {
- Meta: nbpeer.PeerSystemMeta{
- WtVersion: "0.26.0",
- },
- },
- }
- result = computeForwarderPort(peers, "v0.59.0")
- if result != oldForwarderPort {
- t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
- }
-
- // Test with peers that have new versions
- peers = []*nbpeer.Peer{
- {
- Meta: nbpeer.PeerSystemMeta{
- WtVersion: "0.59.0",
- },
- },
- {
- Meta: nbpeer.PeerSystemMeta{
- WtVersion: "0.59.0",
- },
- },
- }
- result = computeForwarderPort(peers, "v0.59.0")
- if result != dnsForwarderPort {
- t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
- }
-
- // Test with peers that have mixed versions
- peers = []*nbpeer.Peer{
- {
- Meta: nbpeer.PeerSystemMeta{
- WtVersion: "0.59.0",
- },
- },
- {
- Meta: nbpeer.PeerSystemMeta{
- WtVersion: "0.57.0",
- },
- },
- }
- result = computeForwarderPort(peers, "v0.59.0")
- if result != oldForwarderPort {
- t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
- }
-
- // Test with peers that have empty version
- peers = []*nbpeer.Peer{
- {
- Meta: nbpeer.PeerSystemMeta{
- WtVersion: "",
- },
- },
- }
- result = computeForwarderPort(peers, "v0.59.0")
- if result != oldForwarderPort {
- t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
- }
-
- peers = []*nbpeer.Peer{
- {
- Meta: nbpeer.PeerSystemMeta{
- WtVersion: "development",
- },
- },
- }
- result = computeForwarderPort(peers, "v0.59.0")
- if result == oldForwarderPort {
- t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
- }
-
- // Test with peers that have unknown version string
- peers = []*nbpeer.Peer{
- {
- Meta: nbpeer.PeerSystemMeta{
- WtVersion: "unknown",
- },
- },
- }
- result = computeForwarderPort(peers, "v0.59.0")
- if result != oldForwarderPort {
- t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
- }
-}
-
func TestDNSAccountPeersUpdate(t *testing.T) {
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+ manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
{
@@ -600,9 +370,9 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
})
assert.NoError(t, err)
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates
diff --git a/management/server/event_test.go b/management/server/event_test.go
index 8c56fd3f6..420e69866 100644
--- a/management/server/event_test.go
+++ b/management/server/event_test.go
@@ -28,7 +28,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
}
func TestDefaultAccountManager_GetEvents(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
return
}
diff --git a/management/server/group.go b/management/server/group.go
index 487cb6d97..84e641f26 100644
--- a/management/server/group.go
+++ b/management/server/group.go
@@ -138,6 +138,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err
}
+ newGroup.AccountID = accountID
+
+ events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
+ eventsToStore = append(eventsToStore, events...)
+
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err != nil {
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
@@ -157,11 +162,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
}
- newGroup.AccountID = accountID
-
- events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
- eventsToStore = append(eventsToStore, events...)
-
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
@@ -335,6 +335,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
+
+ if oldGroup.Name != newGroup.Name {
+ eventsToStore = append(eventsToStore, func() {
+ meta := map[string]any{
+ "old_name": oldGroup.Name,
+ "new_name": newGroup.Name,
+ }
+ am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta)
+ })
+ }
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
eventsToStore = append(eventsToStore, func() {
@@ -354,7 +364,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil
}
- dnsDomain := am.GetDNSDomain(settings)
+ dnsDomain := am.networkMapController.GetDNSDomain(settings)
for _, peerID := range addedPeers {
peer, ok := peers[peerID]
diff --git a/management/server/group_test.go b/management/server/group_test.go
index 31ff29cbc..4935dac5d 100644
--- a/management/server/group_test.go
+++ b/management/server/group_test.go
@@ -37,7 +37,7 @@ const (
)
func TestDefaultAccountManager_CreateGroup(t *testing.T) {
- am, err := createManager(t)
+ am, _, err := createManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -74,7 +74,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
}
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
- am, err := createManager(t)
+ am, _, err := createManager(t)
if err != nil {
t.Fatalf("failed to create account manager: %s", err)
}
@@ -156,7 +156,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
}
func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
- am, err := createManager(t)
+ am, _, err := createManager(t)
assert.NoError(t, err, "Failed to create account manager")
manager, account, err := initTestGroupAccount(am)
@@ -408,7 +408,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
}
func TestGroupAccountPeersUpdate(t *testing.T) {
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+ manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{
{
@@ -442,9 +442,9 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err)
}
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Saving a group that is not linked to any resource should not update account peers
@@ -748,7 +748,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}
func Test_AddPeerToGroup(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -805,7 +805,7 @@ func Test_AddPeerToGroup(t *testing.T) {
}
func Test_AddPeerToAll(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -862,7 +862,7 @@ func Test_AddPeerToAll(t *testing.T) {
}
func Test_AddPeerAndAddToAll(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -942,7 +942,7 @@ func uint32ToIP(n uint32) net.IP {
}
func Test_IncrementNetworkSerial(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
diff --git a/management/server/http/handler.go b/management/server/http/handler.go
index 3d4de31d0..b7c6c113c 100644
--- a/management/server/http/handler.go
+++ b/management/server/http/handler.go
@@ -4,11 +4,16 @@ import (
"context"
"fmt"
"net/http"
+ "os"
+ "strconv"
+ "time"
"github.com/gorilla/mux"
"github.com/rs/cors"
+ log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings"
@@ -16,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
+ nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroups "github.com/netbirdio/netbird/management/server/groups"
@@ -34,11 +40,15 @@ import (
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
- nbpeers "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/telemetry"
)
-const apiPrefix = "/api"
+const (
+ apiPrefix = "/api"
+ rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
+ rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
+ rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
+)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(
@@ -56,13 +66,46 @@ func NewAPIHandler(
permissionsManager permissions.Manager,
peersManager nbpeers.Manager,
settingsManager settings.Manager,
+ networkMapController network_map.Controller,
) (http.Handler, error) {
+ var rateLimitingConfig *middleware.RateLimiterConfig
+ if os.Getenv(rateLimitingEnabledKey) == "true" {
+ rpm := 6
+ if v := os.Getenv(rateLimitingRPMKey); v != "" {
+ value, err := strconv.Atoi(v)
+ if err != nil {
+ log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
+ } else {
+ rpm = value
+ }
+ }
+
+ burst := 500
+ if v := os.Getenv(rateLimitingBurstKey); v != "" {
+ value, err := strconv.Atoi(v)
+ if err != nil {
+ log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
+ } else {
+ burst = value
+ }
+ }
+
+ rateLimitingConfig = &middleware.RateLimiterConfig{
+ RequestsPerMinute: float64(rpm),
+ Burst: burst,
+ CleanupInterval: 6 * time.Hour,
+ LimiterTTL: 24 * time.Hour,
+ }
+ }
+
authMiddleware := middleware.NewAuthMiddleware(
authManager,
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
+ rateLimitingConfig,
+ appMetrics.GetMeter(),
)
corsMiddleware := cors.AllowAll()
@@ -80,7 +123,7 @@ func NewAPIHandler(
}
accounts.AddEndpoints(accountManager, settingsManager, router)
- peers.AddEndpoints(accountManager, router)
+ peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go
index f1552d0ea..3797b0512 100644
--- a/management/server/http/handlers/accounts/accounts_handler.go
+++ b/management/server/http/handlers/accounts/accounts_handler.go
@@ -3,12 +3,15 @@ package accounts
import (
"context"
"encoding/json"
+ "fmt"
"net/http"
"net/netip"
"time"
"github.com/gorilla/mux"
+ goversion "github.com/hashicorp/go-version"
+
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/settings"
@@ -26,7 +29,9 @@ const (
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
MinNetworkBitsIPv4 = 28
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
- MinNetworkBitsIPv6 = 120
+ MinNetworkBitsIPv6 = 120
+ disableAutoUpdate = "disabled"
+ autoUpdateLatestVersion = "latest"
)
// handler is a handler that handles the server.Account HTTP endpoints
@@ -162,6 +167,61 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
+func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
+ returnSettings := &types.Settings{
+ PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
+ PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
+ RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
+
+ PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
+ PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
+ }
+
+ if req.Settings.Extra != nil {
+ returnSettings.Extra = &types.ExtraSettings{
+ PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
+ UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
+ FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
+ FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
+ FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
+ }
+ }
+
+ if req.Settings.JwtGroupsEnabled != nil {
+ returnSettings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
+ }
+ if req.Settings.GroupsPropagationEnabled != nil {
+ returnSettings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
+ }
+ if req.Settings.JwtGroupsClaimName != nil {
+ returnSettings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
+ }
+ if req.Settings.JwtAllowGroups != nil {
+ returnSettings.JWTAllowGroups = *req.Settings.JwtAllowGroups
+ }
+ if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
+ returnSettings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
+ }
+ if req.Settings.DnsDomain != nil {
+ returnSettings.DNSDomain = *req.Settings.DnsDomain
+ }
+ if req.Settings.LazyConnectionEnabled != nil {
+ returnSettings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
+ }
+ if req.Settings.AutoUpdateVersion != nil {
+ _, err := goversion.NewSemver(*req.Settings.AutoUpdateVersion)
+ if *req.Settings.AutoUpdateVersion == autoUpdateLatestVersion ||
+ *req.Settings.AutoUpdateVersion == disableAutoUpdate ||
+ err == nil {
+ returnSettings.AutoUpdateVersion = *req.Settings.AutoUpdateVersion
+ } else if *req.Settings.AutoUpdateVersion != "" {
+ return nil, fmt.Errorf("invalid AutoUpdateVersion")
+ }
+ }
+
+ return returnSettings, nil
+}
+
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
@@ -186,45 +246,10 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return
}
- settings := &types.Settings{
- PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
- PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
- RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
-
- PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
- PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
- }
-
- if req.Settings.Extra != nil {
- settings.Extra = &types.ExtraSettings{
- PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
- UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
- FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
- FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
- FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
- }
- }
-
- if req.Settings.JwtGroupsEnabled != nil {
- settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
- }
- if req.Settings.GroupsPropagationEnabled != nil {
- settings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
- }
- if req.Settings.JwtGroupsClaimName != nil {
- settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
- }
- if req.Settings.JwtAllowGroups != nil {
- settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
- }
- if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
- settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
- }
- if req.Settings.DnsDomain != nil {
- settings.DNSDomain = *req.Settings.DnsDomain
- }
- if req.Settings.LazyConnectionEnabled != nil {
- settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
+ settings, err := h.updateAccountRequestSettings(req)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
}
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
@@ -313,6 +338,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
+ AutoUpdateVersion: &settings.AutoUpdateVersion,
}
if settings.NetworkRange.IsValid() {
diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go
index 4b9b79fdc..2e48ac83e 100644
--- a/management/server/http/handlers/accounts/accounts_handler_test.go
+++ b/management/server/http/handlers/accounts/accounts_handler_test.go
@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -120,6 +121,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
+ AutoUpdateVersion: sr(""),
},
expectedArray: true,
expectedID: accountID,
@@ -142,6 +144,30 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
+ AutoUpdateVersion: sr(""),
+ },
+ expectedArray: false,
+ expectedID: accountID,
+ },
+ {
+ name: "PutAccount OK with autoUpdateVersion",
+ expectedBody: true,
+ requestType: http.MethodPut,
+ requestPath: "/api/accounts/" + accountID,
+ requestBody: bytes.NewBufferString("{\"settings\": {\"auto_update_version\": \"latest\", \"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
+ expectedStatus: http.StatusOK,
+ expectedSettings: api.AccountSettings{
+ PeerLoginExpiration: 15552000,
+ PeerLoginExpirationEnabled: true,
+ GroupsPropagationEnabled: br(false),
+ JwtGroupsClaimName: sr(""),
+ JwtGroupsEnabled: br(false),
+ JwtAllowGroups: &[]string{},
+ RegularUsersViewBlocked: false,
+ RoutingPeerDnsResolutionEnabled: br(false),
+ LazyConnectionEnabled: br(false),
+ DnsDomain: sr(""),
+ AutoUpdateVersion: sr("latest"),
},
expectedArray: false,
expectedID: accountID,
@@ -164,6 +190,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
+ AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,
@@ -186,6 +213,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
+ AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,
@@ -208,6 +236,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
+ AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,
@@ -236,7 +265,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: adminUser.Id,
AccountId: accountID,
Domain: "hotmail.com",
diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go
index 08a0b2afd..67638aea5 100644
--- a/management/server/http/handlers/dns/dns_settings_handler.go
+++ b/management/server/http/handlers/dns/dns_settings_handler.go
@@ -9,9 +9,9 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
- "github.com/netbirdio/netbird/management/server/types"
)
// dnsSettingsHandler is a handler that returns the DNS settings of the account
diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go
index 42b519c29..a027c067e 100644
--- a/management/server/http/handlers/dns/dns_settings_handler_test.go
+++ b/management/server/http/handlers/dns/dns_settings_handler_test.go
@@ -11,13 +11,14 @@ import (
"github.com/stretchr/testify/assert"
+ "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
- "github.com/netbirdio/netbird/management/server/types"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -107,7 +108,7 @@ func TestDNSSettingsHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id,
AccountId: testingDNSSettingsAccount.Id,
Domain: testingDNSSettingsAccount.Domain,
diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go
index d49b6c7e0..4716782f3 100644
--- a/management/server/http/handlers/dns/nameservers_handler_test.go
+++ b/management/server/http/handlers/dns/nameservers_handler_test.go
@@ -19,6 +19,7 @@ import (
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -193,7 +194,7 @@ func TestNameserversHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
AccountId: testNSGroupAccountID,
Domain: "hotmail.com",
diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go
index a0695fa3f..923a24e31 100644
--- a/management/server/http/handlers/events/events_handler_test.go
+++ b/management/server/http/handlers/events/events_handler_test.go
@@ -14,11 +14,12 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/activity"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/management/http/api"
)
func initEventsTestData(account string, events ...*activity.Event) *handler {
@@ -188,7 +189,7 @@ func TestEvents_GetEvents(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_account",
diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go
index e861e873c..56ccc9d0b 100644
--- a/management/server/http/handlers/groups/groups_handler.go
+++ b/management/server/http/handlers/groups/groups_handler.go
@@ -11,10 +11,10 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
- "github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns groups of the account
@@ -48,6 +48,29 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
}
accountID, userID := userAuth.AccountId, userAuth.UserId
+ // Check if filtering by name
+ groupName := r.URL.Query().Get("name")
+ if groupName != "" {
+ // Get single group by name
+ group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ // Return as array with single element to maintain API consistency
+ groupsResponse := []*api.Group{toGroupResponse(accountPeers, group)}
+ util.WriteJSONObject(r.Context(), w, groupsResponse)
+ return
+ }
+
+ // Get all groups
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go
index 34694ec8c..458a15c11 100644
--- a/management/server/http/handlers/groups/groups_handler_test.go
+++ b/management/server/http/handlers/groups/groups_handler_test.go
@@ -19,12 +19,13 @@ import (
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/shared/management/http/api"
- "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
- "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/auth"
+ "github.com/netbirdio/netbird/shared/management/http/api"
+ "github.com/netbirdio/netbird/shared/management/http/util"
+ "github.com/netbirdio/netbird/shared/management/status"
)
var TestPeers = map[string]*nbpeer.Peer{
@@ -59,12 +60,23 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return group, nil
},
+ GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
+ groups := []*types.Group{
+ {ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT},
+ {ID: "id-existed", Name: "Existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI},
+ {ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI},
+ }
+
+ groups = append(groups, initGroups...)
+
+ return groups, nil
+ },
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
}
- return nil, fmt.Errorf("unknown group name")
+ return nil, status.Errorf(status.NotFound, "unknown group name")
},
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
return maps.Values(TestPeers), nil
@@ -122,7 +134,7 @@ func TestGetGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -248,7 +260,7 @@ func TestWriteGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -286,6 +298,84 @@ func TestWriteGroup(t *testing.T) {
}
}
+func TestGetAllGroups(t *testing.T) {
+ tt := []struct {
+ name string
+ expectedStatus int
+ expectedBody bool
+ requestType string
+ requestPath string
+ expectedCount int
+ }{
+ {
+ name: "Get All Groups",
+ expectedBody: true,
+ requestType: http.MethodGet,
+ requestPath: "/api/groups",
+ expectedStatus: http.StatusOK,
+ expectedCount: 3, // id-jwt-group, id-existed, id-all
+ },
+ {
+ name: "Get Group By Name - Existing",
+ expectedBody: true,
+ requestType: http.MethodGet,
+ requestPath: "/api/groups?name=All",
+ expectedStatus: http.StatusOK,
+ expectedCount: 1,
+ },
+ {
+ name: "Get Group By Name - Not Found",
+ expectedBody: false,
+ requestType: http.MethodGet,
+ requestPath: "/api/groups?name=NonExistent",
+ expectedStatus: http.StatusNotFound,
+ },
+ }
+
+ p := initGroupTestData()
+
+ for _, tc := range tt {
+ t.Run(tc.name, func(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
+
+ router := mux.NewRouter()
+ router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET")
+ router.ServeHTTP(recorder, req)
+
+ res := recorder.Result()
+ defer res.Body.Close()
+
+ if status := recorder.Code; status != tc.expectedStatus {
+ t.Errorf("handler returned wrong status code: got %v want %v",
+ status, tc.expectedStatus)
+ return
+ }
+
+ if !tc.expectedBody {
+ return
+ }
+
+ content, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("Failed to read response body: %v", err)
+ }
+
+ var groups []api.Group
+ if err = json.Unmarshal(content, &groups); err != nil {
+ t.Fatalf("Response is not in correct json format; %v", err)
+ }
+
+ assert.Equal(t, tc.expectedCount, len(groups))
+ })
+ }
+}
+
func TestDeleteGroup(t *testing.T) {
tt := []struct {
name string
@@ -330,7 +420,7 @@ func TestDeleteGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go
index d7b598a5d..f99eca794 100644
--- a/management/server/http/handlers/networks/handler.go
+++ b/management/server/http/handlers/networks/handler.go
@@ -12,15 +12,15 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
- "github.com/netbirdio/netbird/shared/management/http/api"
- "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/networks/types"
- "github.com/netbirdio/netbird/shared/management/status"
nbtypes "github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/management/http/api"
+ "github.com/netbirdio/netbird/shared/management/http/util"
+ "github.com/netbirdio/netbird/shared/management/status"
)
// handler is a handler that returns networks of the account
diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go
index 59396dceb..c31729a39 100644
--- a/management/server/http/handlers/networks/resources_handler.go
+++ b/management/server/http/handlers/networks/resources_handler.go
@@ -8,10 +8,10 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
- "github.com/netbirdio/netbird/shared/management/http/api"
- "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
+ "github.com/netbirdio/netbird/shared/management/http/api"
+ "github.com/netbirdio/netbird/shared/management/http/util"
)
type resourceHandler struct {
diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go
index 2e64c637f..c311a29fe 100644
--- a/management/server/http/handlers/networks/routers_handler.go
+++ b/management/server/http/handlers/networks/routers_handler.go
@@ -7,10 +7,10 @@ import (
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/shared/management/http/api"
- "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
+ "github.com/netbirdio/netbird/shared/management/http/api"
+ "github.com/netbirdio/netbird/shared/management/http/util"
)
type routersHandler struct {
diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go
index 4b33495de..a5c9ab0ac 100644
--- a/management/server/http/handlers/peers/peers_handler.go
+++ b/management/server/http/handlers/peers/peers_handler.go
@@ -10,6 +10,7 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -23,11 +24,12 @@ import (
// Handler is a handler that returns peers of the account
type Handler struct {
- accountManager account.Manager
+ accountManager account.Manager
+ networkMapController network_map.Controller
}
-func AddEndpoints(accountManager account.Manager, router *mux.Router) {
- peersHandler := NewHandler(accountManager)
+func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
+ peersHandler := NewHandler(accountManager, networkMapController)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
@@ -36,25 +38,13 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
}
// NewHandler creates a new peers Handler
-func NewHandler(accountManager account.Manager) *Handler {
+func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
return &Handler{
- accountManager: accountManager,
+ accountManager: accountManager,
+ networkMapController: networkMapController,
}
}
-func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
- peerToReturn := peer.Copy()
- if peer.Status.Connected {
- // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
- // This may happen after server restart when not all peers are yet connected
- if !h.accountManager.HasConnectedChannel(peer.ID) {
- peerToReturn.Status.Connected = false
- }
- }
-
- return peerToReturn, nil
-}
-
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
if err != nil {
@@ -62,23 +52,18 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
return
}
- peerToReturn, err := h.checkPeerStatus(peer)
- if err != nil {
- util.WriteError(ctx, err, w)
- return
- }
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
return
}
- dnsDomain := h.accountManager.GetDNSDomain(settings)
+ dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
- validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
+ validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
@@ -86,7 +71,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
}
_, valid := validPeers[peer.ID]
- util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid))
+ reason := invalidPeers[peer.ID]
+
+ util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -137,7 +124,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
util.WriteError(ctx, err, w)
return
}
- dnsDomain := h.accountManager.GetDNSDomain(settings)
+ dnsDomain := h.networkMapController.GetDNSDomain(settings)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
@@ -147,16 +134,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
- validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
+ validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
- log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
+ log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
_, valid := validPeers[peer.ID]
+ reason := invalidPeers[peer.ID]
- util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid))
+ util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -224,38 +212,35 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
- dnsDomain := h.accountManager.GetDNSDomain(settings)
+ dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
- peerToReturn, err := h.checkPeerStatus(peer)
- if err != nil {
- util.WriteError(r.Context(), err, w)
- return
- }
-
- respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
+ respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
}
- validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
+ validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
- log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
+ log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
- h.setApprovalRequiredFlag(respBody, validPeersMap)
+ h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap)
util.WriteJSONObject(r.Context(), w, respBody)
}
-func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
+func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) {
for _, peer := range respBody {
- _, ok := approvedPeersMap[peer.Id]
+ _, ok := validPeersMap[peer.Id]
if !ok {
peer.ApprovalRequired = true
+
+ reason := invalidPeersMap[peer.Id]
+ peer.DisapprovalReason = &reason
}
}
}
@@ -304,17 +289,17 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
}
}
- validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
+ validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
- dnsDomain := h.accountManager.GetDNSDomain(account.Settings)
+ dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
- netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
+ netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
@@ -384,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
PortRanges: []types.RulePortRange{portRange},
}},
}
+ if protocol == types.PolicyRuleProtocolNetbirdSSH {
+ policy.Rules[0].AuthorizedUser = userAuth.UserId
+ }
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
if err != nil {
@@ -430,13 +418,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
}
}
-func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
+func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
}
- return &api.Peer{
+ apiPeer := &api.Peer{
CreatedAt: peer.CreatedAt,
Id: peer.ID,
Name: peer.Name,
@@ -464,7 +452,25 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
+ LocalFlags: &api.PeerLocalFlags{
+ BlockInbound: &peer.Meta.Flags.BlockInbound,
+ BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
+ DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
+ DisableDns: &peer.Meta.Flags.DisableDNS,
+ DisableFirewall: &peer.Meta.Flags.DisableFirewall,
+ DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
+ LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
+ RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
+ RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
+ ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
+ },
}
+
+ if !approved {
+ apiPeer.DisapprovalReason = &reason
+ }
+
+ return apiPeer
}
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {
@@ -472,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
if osVersion == "" {
osVersion = peer.Meta.Core
}
-
return &api.PeerBatch{
CreatedAt: peer.CreatedAt,
Id: peer.ID,
@@ -501,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
+ LocalFlags: &api.PeerLocalFlags{
+ BlockInbound: &peer.Meta.Flags.BlockInbound,
+ BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
+ DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
+ DisableDns: &peer.Meta.Flags.DisableDNS,
+ DisableFirewall: &peer.Meta.Flags.DisableFirewall,
+ DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
+ LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
+ RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
+ RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
+ ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
+ },
}
}
diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go
index 94564113f..55e779ff0 100644
--- a/management/server/http/handlers/peers/peers_handler_test.go
+++ b/management/server/http/handlers/peers/peers_handler_test.go
@@ -14,12 +14,15 @@ import (
"time"
"github.com/gorilla/mux"
+ "go.uber.org/mock/gomock"
"golang.org/x/exp/maps"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbcontext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/shared/management/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/auth"
+ "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -36,7 +39,7 @@ const (
serviceUser = "service_user"
)
-func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
+func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
@@ -99,6 +102,14 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
},
}
+ ctrl := gomock.NewController(t)
+
+ networkMapController := network_map.NewMockController(ctrl)
+ networkMapController.EXPECT().
+ GetDNSDomain(gomock.Any()).
+ Return("domain").
+ AnyTimes()
+
return &Handler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@@ -187,6 +198,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
return account.Settings, nil
},
},
+ networkMapController: networkMapController,
}
}
@@ -249,14 +261,6 @@ func TestGetPeers(t *testing.T) {
expectedArray: false,
expectedPeer: peer,
},
- {
- name: "GetPeer with no update channel",
- requestType: http.MethodGet,
- requestPath: "/api/peers/" + peer1.ID,
- expectedStatus: http.StatusOK,
- expectedArray: false,
- expectedPeer: expectedPeer1,
- },
{
name: "PutPeer",
requestType: http.MethodPut,
@@ -270,14 +274,14 @@ func TestGetPeers(t *testing.T) {
rr := httptest.NewRecorder()
- p := initTestMetaData(peer, peer1)
+ p := initTestMetaData(t, peer, peer1)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "admin_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -316,8 +320,6 @@ func TestGetPeers(t *testing.T) {
for _, peer := range respBody {
if peer.Id == testPeerID {
got = peer
- } else {
- assert.Equal(t, peer.Connected, false)
}
}
@@ -331,14 +333,14 @@ func TestGetPeers(t *testing.T) {
t.Log(got)
- assert.Equal(t, got.Name, tc.expectedPeer.Name)
- assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
- assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
- assert.Equal(t, got.Os, "OS core")
- assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
- assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
- assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
- assert.Equal(t, got.SerialNumber, tc.expectedPeer.Meta.SystemSerialNumber)
+ assert.Equal(t, tc.expectedPeer.Name, got.Name)
+ assert.Equal(t, tc.expectedPeer.Meta.WtVersion, got.Version)
+ assert.Equal(t, tc.expectedPeer.IP.String(), got.Ip)
+ assert.Equal(t, "OS core", got.Os)
+ assert.Equal(t, tc.expectedPeer.LoginExpirationEnabled, got.LoginExpirationEnabled)
+ assert.Equal(t, tc.expectedPeer.SSHEnabled, got.SshEnabled)
+ assert.Equal(t, tc.expectedPeer.Status.Connected, got.Connected)
+ assert.Equal(t, tc.expectedPeer.Meta.SystemSerialNumber, got.SerialNumber)
})
}
}
@@ -374,7 +376,7 @@ func TestGetAccessiblePeers(t *testing.T) {
UserID: regularUser,
}
- p := initTestMetaData(peer1, peer2, peer3)
+ p := initTestMetaData(t, peer1, peer2, peer3)
tt := []struct {
name string
@@ -425,7 +427,7 @@ func TestGetAccessiblePeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: tc.callerUserID,
Domain: "hotmail.com",
AccountId: "test_id",
@@ -477,7 +479,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
},
}
- p := initTestMetaData(testPeer)
+ p := initTestMetaData(t, testPeer)
tt := []struct {
name string
@@ -508,7 +510,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
req.Header.Set("Content-Type", "application/json")
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: tc.callerUserID,
Domain: "hotmail.com",
AccountId: "test_id",
diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go
index cedd5ac88..094a36e38 100644
--- a/management/server/http/handlers/policies/geolocation_handler_test.go
+++ b/management/server/http/handlers/policies/geolocation_handler_test.go
@@ -16,12 +16,13 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/auth"
+ "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/util"
)
@@ -113,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -206,7 +207,7 @@ func TestGetAllCountries(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go
index cb6995793..a2d656a47 100644
--- a/management/server/http/handlers/policies/geolocations_handler.go
+++ b/management/server/http/handlers/policies/geolocations_handler.go
@@ -9,11 +9,11 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
- "github.com/netbirdio/netbird/shared/management/http/api"
- "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
+ "github.com/netbirdio/netbird/shared/management/http/api"
+ "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go
index 4d6bad5e3..e4d1d73df 100644
--- a/management/server/http/handlers/policies/policies_handler.go
+++ b/management/server/http/handlers/policies/policies_handler.go
@@ -10,10 +10,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
+ "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
- "github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns policy of the account
@@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
pr.Protocol = types.PolicyRuleProtocolUDP
case api.PolicyRuleUpdateProtocolIcmp:
pr.Protocol = types.PolicyRuleProtocolICMP
+ case api.PolicyRuleUpdateProtocolNetbirdSsh:
+ pr.Protocol = types.PolicyRuleProtocolNetbirdSSH
default:
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
return
@@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
}
}
+ if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 {
+ for _, sourceGroupID := range pr.Sources {
+ _, ok := (*rule.AuthorizedGroups)[sourceGroupID]
+ if !ok {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w)
+ return
+ }
+ }
+ pr.AuthorizedGroups = *rule.AuthorizedGroups
+ }
+
// validate policy object
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
@@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
DestinationResource: r.DestinationResource.ToAPIResponse(),
}
+ if len(r.AuthorizedGroups) != 0 {
+ authorizedGroupsCopy := r.AuthorizedGroups
+ rule.AuthorizedGroups = &authorizedGroupsCopy
+ }
+
if len(r.Ports) != 0 {
portsCopy := r.Ports
rule.Ports = &portsCopy
diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go
index fd39ae2a3..ca5a0a6ab 100644
--- a/management/server/http/handlers/policies/policies_handler_test.go
+++ b/management/server/http/handlers/policies/policies_handler_test.go
@@ -14,10 +14,11 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
- "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/auth"
+ "github.com/netbirdio/netbird/shared/management/http/api"
+ "github.com/netbirdio/netbird/shared/management/status"
)
func initPoliciesTestData(policies ...*types.Policy) *handler {
@@ -103,7 +104,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -267,7 +268,7 @@ func TestPoliciesWritePolicy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go
index 3ebc4d1e1..744cde10b 100644
--- a/management/server/http/handlers/policies/posture_checks_handler.go
+++ b/management/server/http/handlers/policies/posture_checks_handler.go
@@ -9,9 +9,9 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
+ "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
- "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/status"
)
diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go
index c644b533a..35198da32 100644
--- a/management/server/http/handlers/policies/posture_checks_handler_test.go
+++ b/management/server/http/handlers/policies/posture_checks_handler_test.go
@@ -16,9 +16,10 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/posture"
+ "github.com/netbirdio/netbird/shared/auth"
+ "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -45,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
- return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
+ return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint
}
return postureChecks, nil
@@ -175,7 +176,7 @@ func TestGetPostureCheck(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -828,7 +829,7 @@ func TestPostureCheckUpdate(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go
index 466a7987f..a44d81e3e 100644
--- a/management/server/http/handlers/routes/routes_handler_test.go
+++ b/management/server/http/handlers/routes/routes_handler_test.go
@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
+ "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
@@ -493,7 +494,7 @@ func TestRoutesHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testAccountID,
diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go
index 2287dadfe..d267b6eea 100644
--- a/management/server/http/handlers/setup_keys/setupkeys_handler.go
+++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go
@@ -10,10 +10,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
- "github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns a list of setup keys of the account
diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go
index 7b46b486b..b137b6dd1 100644
--- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go
+++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go
@@ -15,10 +15,11 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
- "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/auth"
+ "github.com/netbirdio/netbird/shared/management/http/api"
+ "github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -163,7 +164,7 @@ func TestSetupKeysHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: adminUser.Id,
Domain: "hotmail.com",
AccountId: "testAccountId",
diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go
index bae07af4a..867db3ca9 100644
--- a/management/server/http/handlers/users/pat_handler.go
+++ b/management/server/http/handlers/users/pat_handler.go
@@ -8,10 +8,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
- "github.com/netbirdio/netbird/management/server/types"
)
// patHandler is the nameserver group handler of the account
diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go
index 92544c56d..7cda14468 100644
--- a/management/server/http/handlers/users/pat_handler_test.go
+++ b/management/server/http/handlers/users/pat_handler_test.go
@@ -17,10 +17,11 @@ import (
"github.com/netbirdio/netbird/management/server/util"
nbcontext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
- "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/auth"
+ "github.com/netbirdio/netbird/shared/management/http/api"
+ "github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -173,7 +174,7 @@ func TestTokenHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go
index e08004218..37f0a6c1d 100644
--- a/management/server/http/handlers/users/users_handler_test.go
+++ b/management/server/http/handlers/users/users_handler_test.go
@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
+ "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -128,7 +129,7 @@ func initUsersTestData() *handler {
return nil
},
- GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
+ GetCurrentUserInfoFunc: func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
switch userAuth.UserId {
case "not-found":
return nil, status.NewUserNotFoundError("not-found")
@@ -225,7 +226,7 @@ func TestGetUsers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -335,7 +336,7 @@ func TestUpdateUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -432,7 +433,7 @@ func TestCreateUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
rr := httptest.NewRecorder()
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -481,7 +482,7 @@ func TestInviteUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -540,7 +541,7 @@ func TestDeleteUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
- req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -565,7 +566,7 @@ func TestCurrentUser(t *testing.T) {
tt := []struct {
name string
expectedStatus int
- requestAuth nbcontext.UserAuth
+ requestAuth auth.UserAuth
expectedResult *api.User
}{
{
@@ -574,27 +575,27 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "user not found",
- requestAuth: nbcontext.UserAuth{UserId: "not-found"},
+ requestAuth: auth.UserAuth{UserId: "not-found"},
expectedStatus: http.StatusNotFound,
},
{
name: "not of account",
- requestAuth: nbcontext.UserAuth{UserId: "not-of-account"},
+ requestAuth: auth.UserAuth{UserId: "not-of-account"},
expectedStatus: http.StatusForbidden,
},
{
name: "blocked user",
- requestAuth: nbcontext.UserAuth{UserId: "blocked-user"},
+ requestAuth: auth.UserAuth{UserId: "blocked-user"},
expectedStatus: http.StatusForbidden,
},
{
name: "service user",
- requestAuth: nbcontext.UserAuth{UserId: "service-user"},
+ requestAuth: auth.UserAuth{UserId: "service-user"},
expectedStatus: http.StatusForbidden,
},
{
name: "owner",
- requestAuth: nbcontext.UserAuth{UserId: "owner"},
+ requestAuth: auth.UserAuth{UserId: "owner"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "owner",
@@ -613,7 +614,7 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "regular user",
- requestAuth: nbcontext.UserAuth{UserId: "regular-user"},
+ requestAuth: auth.UserAuth{UserId: "regular-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "regular-user",
@@ -632,7 +633,7 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "admin user",
- requestAuth: nbcontext.UserAuth{UserId: "admin-user"},
+ requestAuth: auth.UserAuth{UserId: "admin-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "admin-user",
@@ -651,7 +652,7 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "restricted user",
- requestAuth: nbcontext.UserAuth{UserId: "restricted-user"},
+ requestAuth: auth.UserAuth{UserId: "restricted-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "restricted-user",
@@ -783,7 +784,7 @@ func TestApproveUserEndpoint(t *testing.T) {
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
require.NoError(t, err)
- userAuth := nbcontext.UserAuth{
+ userAuth := auth.UserAuth{
AccountId: existingAccountID,
UserId: tc.requestingUser.Id,
}
@@ -841,7 +842,7 @@ func TestRejectUserEndpoint(t *testing.T) {
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
require.NoError(t, err)
- userAuth := nbcontext.UserAuth{
+ userAuth := auth.UserAuth{
AccountId: existingAccountID,
UserId: tc.requestingUser.Id,
}
diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go
index 6091a4c31..38cf0c290 100644
--- a/management/server/http/middleware/auth_middleware.go
+++ b/management/server/http/middleware/auth_middleware.go
@@ -9,40 +9,62 @@ import (
"time"
log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel/metric"
- "github.com/netbirdio/netbird/management/server/auth"
+ serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
-type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
-type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
+type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
+type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error
-type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
+type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
- authManager auth.Manager
+ authManager serverauth.Manager
ensureAccount EnsureAccountFunc
getUserFromUserAuth GetUserFromUserAuthFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
+ rateLimiter *APIRateLimiter
+ patUsageTracker *PATUsageTracker
}
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(
- authManager auth.Manager,
+ authManager serverauth.Manager,
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
+ rateLimiterConfig *RateLimiterConfig,
+ meter metric.Meter,
) *AuthMiddleware {
+ var rateLimiter *APIRateLimiter
+ if rateLimiterConfig != nil {
+ rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
+ }
+
+ var patUsageTracker *PATUsageTracker
+ if meter != nil {
+ var err error
+ patUsageTracker, err = NewPATUsageTracker(context.Background(), meter)
+ if err != nil {
+ log.Errorf("Failed to create PAT usage tracker: %s", err)
+ }
+ }
+
return &AuthMiddleware{
authManager: authManager,
ensureAccount: ensureAccount,
syncUserJWTGroups: syncUserJWTGroups,
getUserFromUserAuth: getUserFromUserAuth,
+ rateLimiter: rateLimiter,
+ patUsageTracker: patUsageTracker,
}
}
@@ -53,18 +75,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return
}
- auth := strings.Split(r.Header.Get("Authorization"), " ")
- authType := strings.ToLower(auth[0])
+ authHeader := strings.Split(r.Header.Get("Authorization"), " ")
+ authType := strings.ToLower(authHeader[0])
// fallback to token when receive pat as bearer
- if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") {
+ if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") {
authType = "token"
- auth[0] = authType
+ authHeader[0] = authType
}
switch authType {
case "bearer":
- request, err := m.checkJWTFromRequest(r, auth)
+ request, err := m.checkJWTFromRequest(r, authHeader)
if err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
@@ -73,10 +95,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
h.ServeHTTP(w, request)
case "token":
- request, err := m.checkPATFromRequest(r, auth)
+ request, err := m.checkPATFromRequest(r, authHeader)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
- util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
+ // Check if it's a status error, otherwise default to Unauthorized
+ if _, ok := status.FromError(err); !ok {
+ err = status.Errorf(status.Unauthorized, "token invalid")
+ }
+ util.WriteError(r.Context(), err, w)
return
}
h.ServeHTTP(w, request)
@@ -88,8 +114,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
}
// CheckJWTFromRequest checks if the JWT is valid
-func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
- token, err := getTokenFromJWTRequest(auth)
+func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
+ token, err := getTokenFromJWTRequest(authHeaderParts)
// If an error occurs, call the error handler and return an error
if err != nil {
@@ -115,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h
}
if userAuth.AccountId != accountId {
- log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
+ log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId
}
@@ -139,12 +165,22 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h
}
// CheckPATFromRequest checks if the PAT is valid
-func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
- token, err := getTokenFromPATRequest(auth)
+func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
+ token, err := getTokenFromPATRequest(authHeaderParts)
if err != nil {
return r, fmt.Errorf("error extracting token: %w", err)
}
+ if m.patUsageTracker != nil {
+ m.patUsageTracker.IncrementUsage(token)
+ }
+
+ if m.rateLimiter != nil {
+ if !m.rateLimiter.Allow(token) {
+ return r, status.Errorf(status.TooManyRequests, "too many requests")
+ }
+ }
+
ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil {
@@ -159,7 +195,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
return r, err
}
- userAuth := nbcontext.UserAuth{
+ userAuth := auth.UserAuth{
UserId: user.Id,
AccountId: user.AccountID,
Domain: accDomain,
diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go
index d815f5422..ba4d16796 100644
--- a/management/server/http/middleware/auth_middleware_test.go
+++ b/management/server/http/middleware/auth_middleware_test.go
@@ -12,11 +12,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/auth"
- nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
+ nbauth "github.com/netbirdio/netbird/shared/auth"
+ nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
const (
@@ -27,7 +28,9 @@ const (
domainCategory = "domainCategory"
userID = "userID"
tokenID = "tokenID"
+ tokenID2 = "tokenID2"
PAT = "nbp_PAT"
+ PAT2 = "nbp_PAT2"
JWT = "JWT"
wrongToken = "wrongToken"
)
@@ -49,6 +52,15 @@ var testAccount = &types.Account{
CreatedAt: time.Now().UTC(),
LastUsed: util.ToPtr(time.Now().UTC()),
},
+ tokenID2: {
+ ID: tokenID2,
+ Name: "My second token",
+ HashedToken: "someHash2",
+ ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)),
+ CreatedBy: userID,
+ CreatedAt: time.Now().UTC(),
+ LastUsed: util.ToPtr(time.Now().UTC()),
+ },
},
},
},
@@ -58,12 +70,15 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
if token == PAT {
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
}
+ if token == PAT2 {
+ return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil
+ }
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
-func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
+func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) {
if token == JWT {
- return nbcontext.UserAuth{
+ return nbauth.UserAuth{
UserId: userID,
AccountId: accountID,
Domain: testAccount.Domain,
@@ -77,17 +92,17 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
Valid: true,
}, nil
}
- return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
+ return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(_ context.Context, token string) error {
- if token == tokenID {
+ if token == tokenID || token == tokenID2 {
return nil
}
return fmt.Errorf("Should never get reached")
}
-func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
+func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) {
if userAuth.IsChild || userAuth.IsPAT {
return userAuth, nil
}
@@ -183,15 +198,17 @@ func TestAuthMiddleware_Handler(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
- func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
+ func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
- func(ctx context.Context, userAuth nbcontext.UserAuth) error {
+ func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
- func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
+ func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
+ nil,
+ nil,
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -221,18 +238,290 @@ func TestAuthMiddleware_Handler(t *testing.T) {
}
}
+func TestAuthMiddleware_RateLimiting(t *testing.T) {
+ mockAuth := &auth.MockManager{
+ ValidateAndParseTokenFunc: mockValidateAndParseToken,
+ EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
+ MarkPATUsedFunc: mockMarkPATUsed,
+ GetPATInfoFunc: mockGetAccountInfoFromPAT,
+ }
+
+ t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) {
+ // Configure rate limiter: 10 requests per minute with burst of 5
+ rateLimitConfig := &RateLimiterConfig{
+ RequestsPerMinute: 10,
+ Burst: 5,
+ CleanupInterval: 5 * time.Minute,
+ LimiterTTL: 10 * time.Minute,
+ }
+
+ authMiddleware := NewAuthMiddleware(
+ mockAuth,
+ func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
+ return userAuth.AccountId, userAuth.UserId, nil
+ },
+ func(ctx context.Context, userAuth nbauth.UserAuth) error {
+ return nil
+ },
+ func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
+ return &types.User{}, nil
+ },
+ rateLimitConfig,
+ nil,
+ )
+
+ handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ // Make burst requests - all should succeed
+ successCount := 0
+ for i := 0; i < 5; i++ {
+ req := httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT)
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+ if rec.Code == http.StatusOK {
+ successCount++
+ }
+ }
+
+ assert.Equal(t, 5, successCount, "All burst requests should succeed")
+
+ // The 6th request should fail (exceeded burst)
+ req := httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT)
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited")
+ })
+
+ t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) {
+ // Configure very low rate limit: 1 request per minute
+ rateLimitConfig := &RateLimiterConfig{
+ RequestsPerMinute: 1,
+ Burst: 1,
+ CleanupInterval: 5 * time.Minute,
+ LimiterTTL: 10 * time.Minute,
+ }
+
+ authMiddleware := NewAuthMiddleware(
+ mockAuth,
+ func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
+ return userAuth.AccountId, userAuth.UserId, nil
+ },
+ func(ctx context.Context, userAuth nbauth.UserAuth) error {
+ return nil
+ },
+ func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
+ return &types.User{}, nil
+ },
+ rateLimitConfig,
+ nil,
+ )
+
+ handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ // First request should succeed
+ req := httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT)
+ rec := httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
+
+ // Second request should fail (rate limited)
+ req = httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT)
+ rec = httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
+ })
+
+ t.Run("Bearer Token Not Rate Limited", func(t *testing.T) {
+ // Configure strict rate limit
+ rateLimitConfig := &RateLimiterConfig{
+ RequestsPerMinute: 1,
+ Burst: 1,
+ CleanupInterval: 5 * time.Minute,
+ LimiterTTL: 10 * time.Minute,
+ }
+
+ authMiddleware := NewAuthMiddleware(
+ mockAuth,
+ func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
+ return userAuth.AccountId, userAuth.UserId, nil
+ },
+ func(ctx context.Context, userAuth nbauth.UserAuth) error {
+ return nil
+ },
+ func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
+ return &types.User{}, nil
+ },
+ rateLimitConfig,
+ nil,
+ )
+
+ handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ // Make multiple requests with Bearer token - all should succeed
+ successCount := 0
+ for i := 0; i < 10; i++ {
+ req := httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Bearer "+JWT)
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+ if rec.Code == http.StatusOK {
+ successCount++
+ }
+ }
+
+ assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)")
+ })
+
+ t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) {
+ // Configure rate limiter
+ rateLimitConfig := &RateLimiterConfig{
+ RequestsPerMinute: 1,
+ Burst: 1,
+ CleanupInterval: 5 * time.Minute,
+ LimiterTTL: 10 * time.Minute,
+ }
+
+ authMiddleware := NewAuthMiddleware(
+ mockAuth,
+ func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
+ return userAuth.AccountId, userAuth.UserId, nil
+ },
+ func(ctx context.Context, userAuth nbauth.UserAuth) error {
+ return nil
+ },
+ func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
+ return &types.User{}, nil
+ },
+ rateLimitConfig,
+ nil,
+ )
+
+ handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ // Use first PAT token
+ req := httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT)
+ rec := httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed")
+
+ // Second request with same token should fail
+ req = httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT)
+ rec = httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited")
+
+ // Use second PAT token - should succeed because it has independent rate limit
+ req = httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT2)
+ rec = httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)")
+
+ // Second request with PAT2 should also be rate limited
+ req = httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT2)
+ rec = httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited")
+
+ // JWT should still work (not rate limited)
+ req = httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Bearer "+JWT)
+ rec = httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)")
+ })
+
+ t.Run("Rate Limiter Cleanup", func(t *testing.T) {
+ // Configure rate limiter with short cleanup interval and TTL for testing
+ rateLimitConfig := &RateLimiterConfig{
+ RequestsPerMinute: 60,
+ Burst: 1,
+ CleanupInterval: 100 * time.Millisecond,
+ LimiterTTL: 200 * time.Millisecond,
+ }
+
+ authMiddleware := NewAuthMiddleware(
+ mockAuth,
+ func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
+ return userAuth.AccountId, userAuth.UserId, nil
+ },
+ func(ctx context.Context, userAuth nbauth.UserAuth) error {
+ return nil
+ },
+ func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
+ return &types.User{}, nil
+ },
+ rateLimitConfig,
+ nil,
+ )
+
+ handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ // First request - should succeed
+ req := httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT)
+ rec := httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
+
+ // Second request immediately - should fail (burst exhausted)
+ req = httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT)
+ rec = httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
+
+ // Wait for limiter to be cleaned up (TTL + cleanup interval + buffer)
+ time.Sleep(400 * time.Millisecond)
+
+ // After cleanup, the limiter should be removed and recreated with full burst capacity
+ req = httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT)
+ rec = httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)")
+
+ // Verify it's a fresh limiter by checking burst is reset
+ req = httptest.NewRequest("GET", "http://testing/test", nil)
+ req.Header.Set("Authorization", "Token "+PAT)
+ rec = httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
+ })
+}
+
func TestAuthMiddleware_Handler_Child(t *testing.T) {
tt := []struct {
name string
path string
authHeader string
- expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
+ expectedUserAuth *nbauth.UserAuth // nil expects 401 response status
}{
{
name: "Valid PAT Token",
path: "/test",
authHeader: "Token " + PAT,
- expectedUserAuth: &nbcontext.UserAuth{
+ expectedUserAuth: &nbauth.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
@@ -244,7 +533,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid PAT Token accesses child",
path: "/test?account=xyz",
authHeader: "Token " + PAT,
- expectedUserAuth: &nbcontext.UserAuth{
+ expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
UserId: userID,
Domain: testAccount.Domain,
@@ -257,7 +546,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid JWT Token",
path: "/test",
authHeader: "Bearer " + JWT,
- expectedUserAuth: &nbcontext.UserAuth{
+ expectedUserAuth: &nbauth.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
@@ -269,7 +558,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid JWT Token with child",
path: "/test?account=xyz",
authHeader: "Bearer " + JWT,
- expectedUserAuth: &nbcontext.UserAuth{
+ expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
UserId: userID,
Domain: testAccount.Domain,
@@ -288,15 +577,17 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
- func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
+ func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
- func(ctx context.Context, userAuth nbcontext.UserAuth) error {
+ func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
- func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
+ func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
+ nil,
+ nil,
)
for _, tc := range tt {
diff --git a/management/server/http/middleware/pat_usage_tracker.go b/management/server/http/middleware/pat_usage_tracker.go
new file mode 100644
index 000000000..331c288e7
--- /dev/null
+++ b/management/server/http/middleware/pat_usage_tracker.go
@@ -0,0 +1,87 @@
+package middleware
+
+import (
+ "context"
+ "maps"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel/metric"
+)
+
+// PATUsageTracker tracks PAT usage metrics
+type PATUsageTracker struct {
+ usageCounters map[string]int64
+ mu sync.Mutex
+ stopChan chan struct{}
+ ctx context.Context
+ histogram metric.Int64Histogram
+}
+
+// NewPATUsageTracker creates a new PAT usage tracker with metrics
+func NewPATUsageTracker(ctx context.Context, meter metric.Meter) (*PATUsageTracker, error) {
+ histogram, err := meter.Int64Histogram(
+ "management.pat.usage_distribution",
+ metric.WithUnit("1"),
+ metric.WithDescription("Distribution of PAT token usage counts per minute"),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ tracker := &PATUsageTracker{
+ usageCounters: make(map[string]int64),
+ stopChan: make(chan struct{}),
+ ctx: ctx,
+ histogram: histogram,
+ }
+
+ go tracker.reportLoop()
+
+ return tracker, nil
+}
+
+// IncrementUsage increments the usage counter for a given token
+func (t *PATUsageTracker) IncrementUsage(token string) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.usageCounters[token]++
+}
+
+// reportLoop reports the usage buckets every minute
+func (t *PATUsageTracker) reportLoop() {
+ ticker := time.NewTicker(1 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ t.reportUsageBuckets()
+ case <-t.stopChan:
+ return
+ }
+ }
+}
+
+// reportUsageBuckets reports all token usage counts and resets counters
+func (t *PATUsageTracker) reportUsageBuckets() {
+ t.mu.Lock()
+ snapshot := maps.Clone(t.usageCounters)
+
+ clear(t.usageCounters)
+ t.mu.Unlock()
+
+ totalTokens := len(snapshot)
+ if totalTokens > 0 {
+ for _, count := range snapshot {
+ t.histogram.Record(t.ctx, count)
+ }
+ log.Debugf("PAT usage in last minute: %d unique tokens used", totalTokens)
+ }
+}
+
+// Stop stops the reporting goroutine
+func (t *PATUsageTracker) Stop() {
+ close(t.stopChan)
+}
diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go
new file mode 100644
index 000000000..a6266d4f3
--- /dev/null
+++ b/management/server/http/middleware/rate_limiter.go
@@ -0,0 +1,146 @@
+package middleware
+
+import (
+ "context"
+ "sync"
+ "time"
+
+ "golang.org/x/time/rate"
+)
+
+// RateLimiterConfig holds configuration for the API rate limiter
+type RateLimiterConfig struct {
+ // RequestsPerMinute defines the rate at which tokens are replenished
+ RequestsPerMinute float64
+ // Burst defines the maximum number of requests that can be made in a burst
+ Burst int
+ // CleanupInterval defines how often to clean up old limiters (how often garbage collection runs)
+ CleanupInterval time.Duration
+ // LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal)
+ LimiterTTL time.Duration
+}
+
+// DefaultRateLimiterConfig returns a default configuration
+func DefaultRateLimiterConfig() *RateLimiterConfig {
+ return &RateLimiterConfig{
+ RequestsPerMinute: 100,
+ Burst: 120,
+ CleanupInterval: 5 * time.Minute,
+ LimiterTTL: 10 * time.Minute,
+ }
+}
+
+// limiterEntry holds a rate limiter and its last access time
+type limiterEntry struct {
+ limiter *rate.Limiter
+ lastAccess time.Time
+}
+
+// APIRateLimiter manages rate limiting for API tokens
+type APIRateLimiter struct {
+ config *RateLimiterConfig
+ limiters map[string]*limiterEntry
+ mu sync.RWMutex
+ stopChan chan struct{}
+}
+
+// NewAPIRateLimiter creates a new API rate limiter with the given configuration
+func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
+ if config == nil {
+ config = DefaultRateLimiterConfig()
+ }
+
+ rl := &APIRateLimiter{
+ config: config,
+ limiters: make(map[string]*limiterEntry),
+ stopChan: make(chan struct{}),
+ }
+
+ go rl.cleanupLoop()
+
+ return rl
+}
+
+// Allow checks if a request for the given key (token) is allowed
+func (rl *APIRateLimiter) Allow(key string) bool {
+ limiter := rl.getLimiter(key)
+ return limiter.Allow()
+}
+
+// Wait blocks until the rate limiter allows another request for the given key
+// Returns an error if the context is canceled
+func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
+ limiter := rl.getLimiter(key)
+ return limiter.Wait(ctx)
+}
+
+// getLimiter retrieves or creates a rate limiter for the given key
+func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter {
+ rl.mu.RLock()
+ entry, exists := rl.limiters[key]
+ rl.mu.RUnlock()
+
+ if exists {
+ rl.mu.Lock()
+ entry.lastAccess = time.Now()
+ rl.mu.Unlock()
+ return entry.limiter
+ }
+
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ if entry, exists := rl.limiters[key]; exists {
+ entry.lastAccess = time.Now()
+ return entry.limiter
+ }
+
+ requestsPerSecond := rl.config.RequestsPerMinute / 60.0
+ limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst)
+ rl.limiters[key] = &limiterEntry{
+ limiter: limiter,
+ lastAccess: time.Now(),
+ }
+
+ return limiter
+}
+
+// cleanupLoop periodically removes old limiters that haven't been used recently
+func (rl *APIRateLimiter) cleanupLoop() {
+ ticker := time.NewTicker(rl.config.CleanupInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ rl.cleanup()
+ case <-rl.stopChan:
+ return
+ }
+ }
+}
+
+// cleanup removes limiters that haven't been used within the TTL period
+func (rl *APIRateLimiter) cleanup() {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ now := time.Now()
+ for key, entry := range rl.limiters {
+ if now.Sub(entry.lastAccess) > rl.config.LimiterTTL {
+ delete(rl.limiters, key)
+ }
+ }
+}
+
+// Stop stops the cleanup goroutine
+func (rl *APIRateLimiter) Stop() {
+ close(rl.stopChan)
+}
+
+// Reset removes the rate limiter for a specific key
+func (rl *APIRateLimiter) Reset(key string) {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+ delete(rl.limiters, key)
+}
diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go
index 741f03f18..e8513feb5 100644
--- a/management/server/http/testing/testing_tools/channel/channel.go
+++ b/management/server/http/testing/testing_tools/channel/channel.go
@@ -7,14 +7,22 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
- "github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/assert"
+ "github.com/netbirdio/management-integrations/integrations"
+ "github.com/netbirdio/netbird/management/internals/server/config"
+
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
+ "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
+
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
- "github.com/netbirdio/netbird/management/server/auth"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
+ serverauth "github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
http2 "github.com/netbirdio/netbird/management/server/http"
@@ -22,15 +30,15 @@ import (
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
- "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/users"
+ "github.com/netbirdio/netbird/shared/auth"
)
-func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
+func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
if err != nil {
t.Fatalf("Failed to create test store: %v", err)
@@ -42,7 +50,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
t.Fatalf("Failed to create metrics: %v", err)
}
- peersUpdateManager := server.NewPeersUpdateManager(nil)
+ peersUpdateManager := update_channel.NewPeersUpdateManager(nil)
updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
done := make(chan struct{})
if validateUpdate {
@@ -62,14 +70,18 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
- am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
+
+ ctx := context.Background()
+ requestBuffer := server.NewAccountRequestBuffer(ctx, store)
+ networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
+ am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
// @note this is required so that PAT's validate from store, but JWT's are mocked
- authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
- authManagerMock := &auth.MockManager{
+ authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
+ authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
MarkPATUsedFunc: authManager.MarkPATUsed,
@@ -82,7 +94,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
- apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
+ apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -90,7 +102,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
return apiHandler, am, done
}
-func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) {
+func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
@@ -100,7 +112,7 @@ func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server
}
}
-func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
+func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage, expected *network_map.UpdateMessage) {
t.Helper()
select {
@@ -114,8 +126,8 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.Up
}
}
-func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
- userAuth := nbcontext.UserAuth{}
+func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) {
+ userAuth := auth.UserAuth{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go
index 66c16870b..bc352f117 100644
--- a/management/server/idp/auth0_test.go
+++ b/management/server/idp/auth0_test.go
@@ -26,9 +26,11 @@ type mockHTTPClient struct {
}
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
- body, err := io.ReadAll(req.Body)
- if err == nil {
- c.reqBody = string(body)
+ if req.Body != nil {
+ body, err := io.ReadAll(req.Body)
+ if err == nil {
+ c.reqBody = string(body)
+ }
}
return &http.Response{
StatusCode: c.code,
diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go
index 51f99b3b7..f06e57196 100644
--- a/management/server/idp/idp.go
+++ b/management/server/idp/idp.go
@@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
APIToken: config.ExtraConfig["ApiToken"],
}
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
+ case "pocketid":
+ pocketidConfig := PocketIdClientConfig{
+ APIToken: config.ExtraConfig["ApiToken"],
+ ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
+ }
+ return NewPocketIdManager(pocketidConfig, appMetrics)
default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
}
diff --git a/management/server/idp/pocketid.go b/management/server/idp/pocketid.go
new file mode 100644
index 000000000..38a5cc67f
--- /dev/null
+++ b/management/server/idp/pocketid.go
@@ -0,0 +1,384 @@
+package idp
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "slices"
+ "strings"
+ "time"
+
+ "github.com/netbirdio/netbird/management/server/telemetry"
+)
+
+type PocketIdManager struct {
+ managementEndpoint string
+ apiToken string
+ httpClient ManagerHTTPClient
+ credentials ManagerCredentials
+ helper ManagerHelper
+ appMetrics telemetry.AppMetrics
+}
+
+type pocketIdCustomClaimDto struct {
+ Key string `json:"key"`
+ Value string `json:"value"`
+}
+
+type pocketIdUserDto struct {
+ CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
+ Disabled bool `json:"disabled"`
+ DisplayName string `json:"displayName"`
+ Email string `json:"email"`
+ FirstName string `json:"firstName"`
+ ID string `json:"id"`
+ IsAdmin bool `json:"isAdmin"`
+ LastName string `json:"lastName"`
+ LdapID string `json:"ldapId"`
+ Locale string `json:"locale"`
+ UserGroups []pocketIdUserGroupDto `json:"userGroups"`
+ Username string `json:"username"`
+}
+
+type pocketIdUserCreateDto struct {
+ Disabled bool `json:"disabled,omitempty"`
+ DisplayName string `json:"displayName"`
+ Email string `json:"email"`
+ FirstName string `json:"firstName"`
+ IsAdmin bool `json:"isAdmin,omitempty"`
+ LastName string `json:"lastName,omitempty"`
+ Locale string `json:"locale,omitempty"`
+ Username string `json:"username"`
+}
+
+type pocketIdPaginatedUserDto struct {
+ Data []pocketIdUserDto `json:"data"`
+ Pagination pocketIdPaginationDto `json:"pagination"`
+}
+
+type pocketIdPaginationDto struct {
+ CurrentPage int `json:"currentPage"`
+ ItemsPerPage int `json:"itemsPerPage"`
+ TotalItems int `json:"totalItems"`
+ TotalPages int `json:"totalPages"`
+}
+
+func (p *pocketIdUserDto) userData() *UserData {
+ return &UserData{
+ Email: p.Email,
+ Name: p.DisplayName,
+ ID: p.ID,
+ AppMetadata: AppMetadata{},
+ }
+}
+
+type pocketIdUserGroupDto struct {
+ CreatedAt string `json:"createdAt"`
+ CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
+ FriendlyName string `json:"friendlyName"`
+ ID string `json:"id"`
+ LdapID string `json:"ldapId"`
+ Name string `json:"name"`
+}
+
+func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) {
+ httpTransport := http.DefaultTransport.(*http.Transport).Clone()
+ httpTransport.MaxIdleConns = 5
+
+ httpClient := &http.Client{
+ Timeout: 10 * time.Second,
+ Transport: httpTransport,
+ }
+ helper := JsonParser{}
+
+ if config.ManagementEndpoint == "" {
+ return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing")
+ }
+
+ if config.APIToken == "" {
+ return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing")
+ }
+
+ credentials := &PocketIdCredentials{
+ clientConfig: config,
+ httpClient: httpClient,
+ helper: helper,
+ appMetrics: appMetrics,
+ }
+
+ return &PocketIdManager{
+ managementEndpoint: config.ManagementEndpoint,
+ apiToken: config.APIToken,
+ httpClient: httpClient,
+ credentials: credentials,
+ helper: helper,
+ appMetrics: appMetrics,
+ }, nil
+}
+
+func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
+ var MethodsWithBody = []string{http.MethodPost, http.MethodPut}
+ if !slices.Contains(MethodsWithBody, method) && body != "" {
+ return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
+ }
+
+ reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
+ if query != nil {
+ reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode())
+ }
+ var req *http.Request
+ var err error
+ if body != "" {
+ req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body))
+ } else {
+ req, err = http.NewRequestWithContext(ctx, method, reqURL, nil)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Add("X-API-KEY", p.apiToken)
+
+ if body != "" {
+ req.Header.Add("content-type", "application/json")
+ req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength))
+ }
+
+ resp, err := p.httpClient.Do(req)
+ if err != nil {
+ if p.appMetrics != nil {
+ p.appMetrics.IDPMetrics().CountRequestError()
+ }
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
+ if p.appMetrics != nil {
+ p.appMetrics.IDPMetrics().CountRequestStatusError()
+ }
+
+ return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode)
+ }
+
+ return io.ReadAll(resp.Body)
+}
+
+// getAllUsersPaginated fetches all users from PocketID API using pagination
+func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) {
+ var allUsers []pocketIdUserDto
+ currentPage := 1
+
+ for {
+ params := url.Values{}
+ // Copy existing search parameters
+ for key, values := range searchParams {
+ params[key] = values
+ }
+
+ params.Set("pagination[limit]", "100")
+ params.Set("pagination[page]", fmt.Sprintf("%d", currentPage))
+
+ body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
+ if err != nil {
+ return nil, err
+ }
+
+ var profiles pocketIdPaginatedUserDto
+ err = p.helper.Unmarshal(body, &profiles)
+ if err != nil {
+ return nil, err
+ }
+
+ allUsers = append(allUsers, profiles.Data...)
+
+ // Check if we've reached the last page
+ if currentPage >= profiles.Pagination.TotalPages {
+ break
+ }
+
+ currentPage++
+ }
+
+ return allUsers, nil
+}
+
+func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
+ return nil
+}
+
+func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
+ body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "")
+ if err != nil {
+ return nil, err
+ }
+
+ if p.appMetrics != nil {
+ p.appMetrics.IDPMetrics().CountGetUserDataByID()
+ }
+
+ var user pocketIdUserDto
+ err = p.helper.Unmarshal(body, &user)
+ if err != nil {
+ return nil, err
+ }
+
+ userData := user.userData()
+ userData.AppMetadata = appMetadata
+
+ return userData, nil
+}
+
+func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
+ // Get all users using pagination
+ allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
+ if err != nil {
+ return nil, err
+ }
+
+ if p.appMetrics != nil {
+ p.appMetrics.IDPMetrics().CountGetAccount()
+ }
+
+ users := make([]*UserData, 0)
+ for _, profile := range allUsers {
+ userData := profile.userData()
+ userData.AppMetadata.WTAccountID = accountId
+
+ users = append(users, userData)
+ }
+ return users, nil
+}
+
+func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
+ // Get all users using pagination
+ allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
+ if err != nil {
+ return nil, err
+ }
+
+ if p.appMetrics != nil {
+ p.appMetrics.IDPMetrics().CountGetAllAccounts()
+ }
+
+ indexedUsers := make(map[string][]*UserData)
+ for _, profile := range allUsers {
+ userData := profile.userData()
+ indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
+ }
+
+ return indexedUsers, nil
+}
+
+func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
+ firstLast := strings.Split(name, " ")
+
+ createUser := pocketIdUserCreateDto{
+ Disabled: false,
+ DisplayName: name,
+ Email: email,
+ FirstName: firstLast[0],
+ LastName: firstLast[1],
+ Username: firstLast[0] + "." + firstLast[1],
+ }
+ payload, err := p.helper.Marshal(createUser)
+ if err != nil {
+ return nil, err
+ }
+
+ body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload))
+ if err != nil {
+ return nil, err
+ }
+ var newUser pocketIdUserDto
+ err = p.helper.Unmarshal(body, &newUser)
+ if err != nil {
+ return nil, err
+ }
+
+ if p.appMetrics != nil {
+ p.appMetrics.IDPMetrics().CountCreateUser()
+ }
+ var pending bool = true
+ ret := &UserData{
+ Email: email,
+ Name: name,
+ ID: newUser.ID,
+ AppMetadata: AppMetadata{
+ WTAccountID: accountID,
+ WTPendingInvite: &pending,
+ WTInvitedBy: invitedByEmail,
+ },
+ }
+ return ret, nil
+}
+
+func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
+ params := url.Values{
+ // This value a
+ "search": []string{email},
+ }
+ body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
+ if err != nil {
+ return nil, err
+ }
+
+ if p.appMetrics != nil {
+ p.appMetrics.IDPMetrics().CountGetUserByEmail()
+ }
+
+ var profiles struct{ data []pocketIdUserDto }
+ err = p.helper.Unmarshal(body, &profiles)
+ if err != nil {
+ return nil, err
+ }
+
+ users := make([]*UserData, 0)
+ for _, profile := range profiles.data {
+ users = append(users, profile.userData())
+ }
+ return users, nil
+}
+
+func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error {
+ _, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "")
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error {
+ _, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "")
+ if err != nil {
+ return err
+ }
+
+ if p.appMetrics != nil {
+ p.appMetrics.IDPMetrics().CountDeleteUser()
+ }
+
+ return nil
+}
+
+var _ Manager = (*PocketIdManager)(nil)
+
+type PocketIdClientConfig struct {
+ APIToken string
+ ManagementEndpoint string
+}
+
+type PocketIdCredentials struct {
+ clientConfig PocketIdClientConfig
+ helper ManagerHelper
+ httpClient ManagerHTTPClient
+ appMetrics telemetry.AppMetrics
+}
+
+var _ ManagerCredentials = (*PocketIdCredentials)(nil)
+
+func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) {
+ return JWTToken{}, nil
+}
diff --git a/management/server/idp/pocketid_test.go b/management/server/idp/pocketid_test.go
new file mode 100644
index 000000000..126a76919
--- /dev/null
+++ b/management/server/idp/pocketid_test.go
@@ -0,0 +1,137 @@
+package idp
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/netbirdio/netbird/management/server/telemetry"
+)
+
+func TestNewPocketIdManager(t *testing.T) {
+ type test struct {
+ name string
+ inputConfig PocketIdClientConfig
+ assertErrFunc require.ErrorAssertionFunc
+ assertErrFuncMessage string
+ }
+
+ defaultTestConfig := PocketIdClientConfig{
+ APIToken: "api_token",
+ ManagementEndpoint: "http://localhost",
+ }
+
+ tests := []test{
+ {
+ name: "Good Configuration",
+ inputConfig: defaultTestConfig,
+ assertErrFunc: require.NoError,
+ assertErrFuncMessage: "shouldn't return error",
+ },
+ {
+ name: "Missing ManagementEndpoint",
+ inputConfig: PocketIdClientConfig{
+ APIToken: defaultTestConfig.APIToken,
+ ManagementEndpoint: "",
+ },
+ assertErrFunc: require.Error,
+ assertErrFuncMessage: "should return error when field empty",
+ },
+ {
+ name: "Missing APIToken",
+ inputConfig: PocketIdClientConfig{
+ APIToken: "",
+ ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
+ },
+ assertErrFunc: require.Error,
+ assertErrFuncMessage: "should return error when field empty",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
+ tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
+ })
+ }
+}
+
+func TestPocketID_GetUserDataByID(t *testing.T) {
+ client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
+
+ mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
+ require.NoError(t, err)
+ mgr.httpClient = client
+
+ md := AppMetadata{WTAccountID: "acc1"}
+ got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
+ require.NoError(t, err)
+ assert.Equal(t, "u1", got.ID)
+ assert.Equal(t, "user1@example.com", got.Email)
+ assert.Equal(t, "User One", got.Name)
+ assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
+}
+
+func TestPocketID_GetAccount_WithPagination(t *testing.T) {
+ // Single page response with two users
+ client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
+
+ mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
+ require.NoError(t, err)
+ mgr.httpClient = client
+
+ users, err := mgr.GetAccount(context.Background(), "accX")
+ require.NoError(t, err)
+ require.Len(t, users, 2)
+ assert.Equal(t, "u1", users[0].ID)
+ assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
+ assert.Equal(t, "u2", users[1].ID)
+}
+
+func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
+ client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
+
+ mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
+ require.NoError(t, err)
+ mgr.httpClient = client
+
+ accounts, err := mgr.GetAllAccounts(context.Background())
+ require.NoError(t, err)
+ require.Len(t, accounts[UnsetAccountID], 2)
+}
+
+func TestPocketID_CreateUser(t *testing.T) {
+ client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
+
+ mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
+ require.NoError(t, err)
+ mgr.httpClient = client
+
+ ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
+ require.NoError(t, err)
+ assert.Equal(t, "newid", ud.ID)
+ assert.Equal(t, "new@example.com", ud.Email)
+ assert.Equal(t, "New User", ud.Name)
+ assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
+ if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
+ assert.True(t, *ud.AppMetadata.WTPendingInvite)
+ }
+ assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
+}
+
+func TestPocketID_InviteAndDeleteUser(t *testing.T) {
+ // Same mock for both calls; returns OK with empty JSON
+ client := &mockHTTPClient{code: 200, resBody: `{}`}
+
+ mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
+ require.NoError(t, err)
+ mgr.httpClient = client
+
+ err = mgr.InviteUserByID(context.Background(), "u1")
+ require.NoError(t, err)
+
+ err = mgr.DeleteUser(context.Background(), "u1")
+ require.NoError(t, err)
+}
diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go
index 21f11bfce..69ea668ad 100644
--- a/management/server/integrated_validator.go
+++ b/management/server/integrated_validator.go
@@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
return true, nil
}
-func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
+func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
var err error
var groups []*types.Group
var peers []*nbpeer.Peer
@@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
- return nil, err
+ return nil, nil, err
}
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
- return nil, err
+ return nil, nil, err
}
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
- return nil, err
+ return nil, nil, err
}
- return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
+ validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return validPeers, invalidPeers, nil
}
type MockIntegratedValidator struct {
@@ -117,7 +127,7 @@ type MockIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
}
-func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
+func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error {
return nil
}
@@ -136,7 +146,11 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
return validatedPeers, nil
}
-func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
+func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) {
+ return make(map[string]string), nil
+}
+
+func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
return peer
}
diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go
index ce632d567..326fbfaf0 100644
--- a/management/server/integrations/integrated_validator/interface.go
+++ b/management/server/integrations/integrated_validator/interface.go
@@ -3,18 +3,19 @@ package integrated_validator
import (
"context"
- "github.com/netbirdio/netbird/shared/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/management/proto"
)
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
- ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
+ ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
- PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer
+ PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
+ GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error)
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
SetPeerInvalidationListener(fn func(accountID string, peerIDs []string))
Stop(ctx context.Context)
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index a34d2086b..42f192c0a 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -22,11 +22,16 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
+ nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
- "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -321,99 +326,6 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ
return loginResp, nil
}
-func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
- testingServerKey, err := wgtypes.GeneratePrivateKey()
- if err != nil {
- t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
- }
-
- testingClientKey, err := wgtypes.GeneratePrivateKey()
- if err != nil {
- t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
- }
-
- testCases := []struct {
- name string
- inputFlow *config.DeviceAuthorizationFlow
- expectedFlow *mgmtProto.DeviceAuthorizationFlow
- expectedErrFunc require.ErrorAssertionFunc
- expectedErrMSG string
- expectedComparisonFunc require.ComparisonAssertionFunc
- expectedComparisonMSG string
- }{
- {
- name: "Testing No Device Flow Config",
- inputFlow: nil,
- expectedErrFunc: require.Error,
- expectedErrMSG: "should return error",
- },
- {
- name: "Testing Invalid Device Flow Provider Config",
- inputFlow: &config.DeviceAuthorizationFlow{
- Provider: "NoNe",
- ProviderConfig: config.ProviderConfig{
- ClientID: "test",
- },
- },
- expectedErrFunc: require.Error,
- expectedErrMSG: "should return error",
- },
- {
- name: "Testing Full Device Flow Config",
- inputFlow: &config.DeviceAuthorizationFlow{
- Provider: "hosted",
- ProviderConfig: config.ProviderConfig{
- ClientID: "test",
- },
- },
- expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
- Provider: 0,
- ProviderConfig: &mgmtProto.ProviderConfig{
- ClientID: "test",
- },
- },
- expectedErrFunc: require.NoError,
- expectedErrMSG: "should not return error",
- expectedComparisonFunc: require.Equal,
- expectedComparisonMSG: "should match",
- },
- }
-
- for _, testCase := range testCases {
- t.Run(testCase.name, func(t *testing.T) {
- mgmtServer := &GRPCServer{
- wgKey: testingServerKey,
- config: &config.Config{
- DeviceAuthorizationFlow: testCase.inputFlow,
- },
- }
-
- message := &mgmtProto.DeviceAuthorizationFlowRequest{}
-
- encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
- require.NoError(t, err, "should be able to encrypt message")
-
- resp, err := mgmtServer.GetDeviceAuthorizationFlow(
- context.TODO(),
- &mgmtProto.EncryptedMessage{
- WgPubKey: testingClientKey.PublicKey().String(),
- Body: encryptedMSG,
- },
- )
- testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
- if testCase.expectedComparisonFunc != nil {
- flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
-
- err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
- require.NoError(t, err, "should be able to decrypt")
-
- testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
- testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
- }
- })
- }
-}
-
func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
@@ -427,7 +339,6 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
t.Fatal(err)
}
- peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
@@ -451,7 +362,12 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
- accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := NewAccountRequestBuffer(ctx, store)
+ ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager))
+
+ networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config)
+ accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
@@ -459,10 +375,13 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
return nil, nil, "", cleanup, err
}
- secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
+ secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
+ if err != nil {
+ cleanup()
+ return nil, nil, "", cleanup, err
+ }
- ephemeralMgr := manager.NewEphemeralManager(store, accountManager)
- mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
+ mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, nil, "", cleanup, err
}
@@ -764,9 +683,38 @@ func Test_LoginPerformance(t *testing.T) {
peerLogin := types.PeerLogin{
WireGuardPubKey: key.String(),
SSHKey: "random",
- Meta: extractPeerMeta(context.Background(), meta),
- SetupKey: setupKey.Key,
- ConnectionIP: net.IP{1, 1, 1, 1},
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: meta.GetHostname(),
+ GoOS: meta.GetGoOS(),
+ Kernel: meta.GetKernel(),
+ Platform: meta.GetPlatform(),
+ OS: meta.GetOS(),
+ OSVersion: meta.GetOSVersion(),
+ WtVersion: meta.GetNetbirdVersion(),
+ UIVersion: meta.GetUiVersion(),
+ KernelVersion: meta.GetKernelVersion(),
+ SystemSerialNumber: meta.GetSysSerialNumber(),
+ SystemProductName: meta.GetSysProductName(),
+ SystemManufacturer: meta.GetSysManufacturer(),
+ Environment: nbpeer.Environment{
+ Cloud: meta.GetEnvironment().GetCloud(),
+ Platform: meta.GetEnvironment().GetPlatform(),
+ },
+ Flags: nbpeer.Flags{
+ RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(),
+ RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(),
+ ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(),
+ DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(),
+ DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(),
+ DisableDNS: meta.GetFlags().GetDisableDNS(),
+ DisableFirewall: meta.GetFlags().GetDisableFirewall(),
+ BlockLANAccess: meta.GetFlags().GetBlockLANAccess(),
+ BlockInbound: meta.GetFlags().GetBlockInbound(),
+ LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
+ },
+ },
+ SetupKey: setupKey.Key,
+ ConnectionIP: net.IP{1, 1, 1, 1},
}
login := func() error {
diff --git a/management/server/management_test.go b/management/server/management_test.go
index 1a5e47354..648201d4e 100644
--- a/management/server/management_test.go
+++ b/management/server/management_test.go
@@ -20,12 +20,16 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
+ nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
- "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -176,7 +180,6 @@ func startServer(
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
- peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -199,13 +202,19 @@ func startServer(
AnyTimes()
permissionsManager := permissions.NewManager(str)
+
+ ctx := context.Background()
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := server.NewAccountRequestBuffer(ctx, str)
+ networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config)
+
accountManager, err := server.BuildManager(
context.Background(),
+ nil,
str,
- peersUpdateManager,
+ networkMapController,
nil,
"",
- "netbird.selfhosted",
eventStore,
nil,
false,
@@ -220,18 +229,19 @@ func startServer(
}
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
- secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
- mgmtServer, err := server.NewServer(
- context.Background(),
+ secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
+ if err != nil {
+ t.Fatalf("failed creating secrets manager: %v", err)
+ }
+ mgmtServer, err := nbgrpc.NewServer(
config,
accountManager,
settingsMockManager,
- peersUpdateManager,
secretsManager,
nil,
- &manager.EphemeralManager{},
nil,
server.MockIntegratedValidator{},
+ networkMapController,
)
if err != nil {
t.Fatalf("failed creating management server: %v", err)
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index d160e7269..928098dbe 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -2,6 +2,7 @@ package mock_server
import (
"context"
+ "github.com/netbirdio/netbird/shared/auth"
"net"
"net/netip"
"time"
@@ -12,10 +13,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
- "github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -34,11 +33,11 @@ type MockAccountManager struct {
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
- GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
+ GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
- SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
+ SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
@@ -84,7 +83,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
- GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
+ GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func(settings *types.Settings) string
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
@@ -94,7 +93,7 @@ type MockAccountManager struct {
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
- SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
+ SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
@@ -119,15 +118,16 @@ type MockAccountManager struct {
GetStoreFunc func() store.Store
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
- GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
+ GetCurrentUserInfoFunc func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
- AllowSyncFunc func(string, uint64) bool
- UpdateAccountPeersFunc func(ctx context.Context, accountID string)
- BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
+ AllowSyncFunc func(string, uint64) bool
+ UpdateAccountPeersFunc func(ctx context.Context, accountID string)
+ BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
+ RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
}
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
@@ -177,11 +177,11 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
}
-func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
+func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
}
- return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
+ return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
@@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me")
}
-func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
+func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
account, err := am.GetAccountFunc(ctx, accountID)
if err != nil {
- return nil, err
+ return nil, nil, err
}
approvedPeers := make(map[string]struct{})
for id := range account.Peers {
approvedPeers[id] = struct{}{}
}
- return approvedPeers, nil
+ return approvedPeers, nil, nil
}
// GetGroup mock implementation of GetGroup from server.AccountManager interface
@@ -469,7 +469,7 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
}
// GetUser mock implementation of GetUser from server.AccountManager interface
-func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
+func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) {
if am.GetUserFromUserAuthFunc != nil {
return am.GetUserFromUserAuthFunc(ctx, userAuth)
}
@@ -674,7 +674,7 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
-func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
+func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
if am.GetAccountIDFromUserAuthFunc != nil {
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
}
@@ -746,11 +746,11 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog
}
// SyncPeer mocks SyncPeer of the AccountManager interface
-func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
+func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(ctx, sync, accountID)
}
- return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
+ return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
}
// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface
@@ -936,7 +936,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
}
-func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
+func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented")
}
@@ -968,21 +968,23 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string
return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented")
}
-func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
+func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
if am.GetCurrentUserInfoFunc != nil {
return am.GetCurrentUserInfoFunc(ctx, userAuth)
}
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
}
-// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface
-func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) {
- // Mock implementation - does nothing
-}
-
func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
if am.AllowSyncFunc != nil {
return am.AllowSyncFunc(key, hash)
}
return true
}
+
+func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error {
+ if am.RecalculateNetworkMapCacheFunc != nil {
+ return am.RecalculateNetworkMapCacheFunc(ctx, accountID)
+ }
+ return nil
+}
diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go
index 6c985410c..e3dd8b0b8 100644
--- a/management/server/nameserver_test.go
+++ b/management/server/nameserver_test.go
@@ -11,6 +11,11 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
+ "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -785,7 +790,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes()
permissionsManager := permissions.NewManager(store)
- return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
+
+ ctx := context.Background()
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := NewAccountRequestBuffer(ctx, store)
+ networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
+
+ return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createNSStore(t *testing.T) (store.Store, error) {
@@ -975,7 +986,7 @@ func TestValidateDomain(t *testing.T) {
}
func TestNameServerAccountPeersUpdate(t *testing.T) {
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+ manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
var newNameServerGroupA *nbdns.NameServerGroup
var newNameServerGroupB *nbdns.NameServerGroup
@@ -994,9 +1005,9 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
})
assert.NoError(t, err)
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Creating a nameserver group with a distribution group no peers should not update account peers
diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go
index c6cec6f7e..e2dea2c6b 100644
--- a/management/server/networks/resources/manager_test.go
+++ b/management/server/networks/resources/manager_test.go
@@ -10,8 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/permissions"
- "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store"
+ "github.com/netbirdio/netbird/shared/management/status"
)
func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go
index 7874be858..6b8cf9412 100644
--- a/management/server/networks/resources/types/resource.go
+++ b/management/server/networks/resources/types/resource.go
@@ -8,11 +8,11 @@ import (
"github.com/rs/xid"
- nbDomain "github.com/netbirdio/netbird/shared/management/domain"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route"
+ nbDomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
)
diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go
index 8054d05c6..6be90baa7 100644
--- a/management/server/networks/routers/manager_test.go
+++ b/management/server/networks/routers/manager_test.go
@@ -9,8 +9,8 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions"
- "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store"
+ "github.com/netbirdio/netbird/shared/management/status"
)
func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {
diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go
index 72b15fd9a..e90c61a97 100644
--- a/management/server/networks/routers/types/router.go
+++ b/management/server/networks/routers/types/router.go
@@ -5,8 +5,8 @@ import (
"github.com/rs/xid"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/networks/types"
+ "github.com/netbirdio/netbird/shared/management/http/api"
)
type NetworkRouter struct {
diff --git a/management/server/peer.go b/management/server/peer.go
index 4cf5d1e46..7c48a8052 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -8,8 +8,6 @@ import (
"net"
"slices"
"strings"
- "sync"
- "sync/atomic"
"time"
"github.com/rs/xid"
@@ -23,7 +21,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/domain"
- "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
@@ -31,7 +28,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
- "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -95,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// fetch all the peers that have access to the user's peers
for _, peer := range peers {
- aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
+ aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers())
for _, p := range aclPeers {
peersMap[p.ID] = p
}
@@ -106,11 +102,6 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
- start := time.Now()
- defer func() {
- log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start))
- }()
-
var peer *nbpeer.Peer
var settings *types.Settings
var expired bool
@@ -145,9 +136,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
- // we need to update other peers because when peer login expires all other peers are notified to disconnect from
- // the expired one. Here we notify them that connection is now allowed again.
- am.BufferUpdateAccountPeers(ctx, accountID)
+ err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
+ if err != nil {
+ return fmt.Errorf("notify network map controller of peer update: %w", err)
+ }
}
return nil
@@ -180,7 +172,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
}
}
- log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
+ log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
if err != nil {
@@ -203,7 +195,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
var peer *nbpeer.Peer
var settings *types.Settings
var peerGroupList []string
- var requiresPeerUpdates bool
var peerLabelChanged bool
var sshChanged bool
var loginExpirationChanged bool
@@ -226,9 +217,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err
}
- dnsDomain = am.GetDNSDomain(settings)
+ dnsDomain = am.networkMapController.GetDNSDomain(settings)
- update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra)
+ update, _, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra)
if err != nil {
return err
}
@@ -321,10 +312,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
- if peerLabelChanged || requiresPeerUpdates {
- am.UpdateAccountPeers(ctx, accountID)
- } else if sshChanged {
- am.UpdateAccountPeer(ctx, accountID, peer.ID)
+ err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
+ if err != nil {
+ return nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
return peer, nil
@@ -350,7 +340,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
var peer *nbpeer.Peer
- var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -363,11 +352,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
- updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
- if err != nil {
- return err
- }
-
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -387,8 +371,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
- if updateAccountPeers && userID != activity.SystemInitiator {
- am.BufferUpdateAccountPeers(ctx, accountID)
+ if err := am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil {
+ log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
}
return nil
@@ -396,41 +380,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
- account, err := am.Store.GetAccountByPeerID(ctx, peerID)
- if err != nil {
- return nil, err
- }
-
- peer := account.GetPeer(peerID)
- if peer == nil {
- return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
- }
-
- groups := make(map[string][]string)
- for groupID, group := range account.Groups {
- groups[groupID] = group.Peers
- }
-
- validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
- if err != nil {
- return nil, err
- }
- customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
-
- proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
- return nil, err
- }
-
- networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
-
- proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
- if ok {
- networkMap.Merge(proxyNetworkMap)
- }
-
- return networkMap, nil
+ return am.networkMapController.GetNetworkMap(ctx, peerID)
}
// GetPeerNetwork returns the Network for a given peer
@@ -584,7 +534,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
}
- newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
+ newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -634,11 +584,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return fmt.Errorf("failed adding peer to All group: %w", err)
}
- if temporary {
- // we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually
- am.ephemeralManager.OnPeerDisconnected(ctx, newPeer)
- }
-
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
@@ -684,28 +629,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
}
- updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
- if err != nil {
- updateAccountPeers = true
- }
-
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
}
opEvent.TargetID = newPeer.ID
- opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
+ opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings))
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
- if updateAccountPeers {
- am.BufferUpdateAccountPeers(ctx, accountID)
+ if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
+ log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
- return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
+ p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
+ return p, nmap, pc, err
}
func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
@@ -720,12 +661,7 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
-func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
- start := time.Now()
- defer func() {
- log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start))
- }()
-
+func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
var peer *nbpeer.Peer
var peerNotValid bool
var isStatusChanged bool
@@ -735,7 +671,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
- return nil, nil, nil, err
+ return nil, nil, nil, 0, err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -785,14 +721,17 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil
})
if err != nil {
- return nil, nil, nil, err
+ return nil, nil, nil, 0, err
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
- am.BufferUpdateAccountPeers(ctx, accountID)
+ err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
+ if err != nil {
+ return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
+ }
}
- return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
+ return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
}
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
@@ -914,10 +853,14 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
- am.BufferUpdateAccountPeers(ctx, accountID)
+ err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
+ }
}
- return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
+ p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
+ return p, nmap, pc, err
}
// getPeerPostureChecks returns the posture checks for the peer.
@@ -1009,57 +952,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
-func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
- start := time.Now()
- defer func() {
- log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start))
- }()
-
- if isRequiresApproval {
- network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
- if err != nil {
- return nil, nil, nil, err
- }
-
- emptyMap := &types.NetworkMap{
- Network: network.Copy(),
- }
- return peer, emptyMap, nil, nil
- }
-
- account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
- if err != nil {
- return nil, nil, nil, err
- }
-
- approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
- if err != nil {
- return nil, nil, nil, err
- }
-
- postureChecks, err := am.getPeerPostureChecks(account, peer.ID)
- if err != nil {
- return nil, nil, nil, err
- }
-
- customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
-
- proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
- return nil, nil, nil, err
- }
-
- networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
-
- proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
- if ok {
- networkMap.Merge(proxyNetworkMap)
- }
-
- return peer, networkMap, postureChecks, nil
-}
-
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error {
err := checkAuth(ctx, user.Id, peer)
if err != nil {
@@ -1083,7 +975,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
return fmt.Errorf("failed to get account settings: %w", err)
}
- am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings)))
+ am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.networkMapController.GetDNSDomain(settings)))
return nil
}
@@ -1165,7 +1057,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
}
for _, p := range userPeers {
- aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
+ aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers())
for _, aclPeer := range aclPeers {
if aclPeer.ID == peer.ID {
return peer, nil
@@ -1179,209 +1071,17 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// UpdateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
- log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
-
- account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
- return
- }
-
- globalStart := time.Now()
-
- hasPeersConnected := false
- for _, peer := range account.Peers {
- if am.peersUpdateManager.HasChannel(peer.ID) {
- hasPeersConnected = true
- break
- }
-
- }
-
- if !hasPeersConnected {
- return
- }
-
- approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err)
- return
- }
-
- var wg sync.WaitGroup
- semaphore := make(chan struct{}, 10)
-
- dnsCache := &DNSConfigCache{}
- dnsDomain := am.GetDNSDomain(account.Settings)
- customZone := account.GetPeersCustomZone(ctx, dnsDomain)
- resourcePolicies := account.GetResourcePoliciesMap()
- routers := account.GetResourceRoutersMap()
-
- proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
- return
- }
-
- extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
- return
- }
-
- dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
-
- for _, peer := range account.Peers {
- if !am.peersUpdateManager.HasChannel(peer.ID) {
- log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
- continue
- }
-
- wg.Add(1)
- semaphore <- struct{}{}
- go func(p *nbpeer.Peer) {
- defer wg.Done()
- defer func() { <-semaphore }()
-
- start := time.Now()
-
- postureChecks, err := am.getPeerPostureChecks(account, p.ID)
- if err != nil {
- log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
- return
- }
-
- am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
- start = time.Now()
-
- remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
-
- am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
- start = time.Now()
-
- proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
- if ok {
- remotePeerNetworkMap.Merge(proxyNetworkMap)
- }
- am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
-
- peerGroups := account.GetPeerGroups(p.ID)
- start = time.Now()
- update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
- am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
-
- am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
- }(peer)
- }
-
- //
-
- wg.Wait()
- if am.metrics != nil {
- am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
- }
-}
-
-type bufferUpdate struct {
- mu sync.Mutex
- next *time.Timer
- update atomic.Bool
+ _ = am.networkMapController.UpdateAccountPeers(ctx, accountID)
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
- log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
-
- bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
- b := bufUpd.(*bufferUpdate)
-
- if !b.mu.TryLock() {
- b.update.Store(true)
- return
- }
-
- if b.next != nil {
- b.next.Stop()
- }
-
- go func() {
- defer b.mu.Unlock()
- am.UpdateAccountPeers(ctx, accountID)
- if !b.update.Load() {
- return
- }
- b.update.Store(false)
- if b.next == nil {
- b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() {
- am.UpdateAccountPeers(ctx, accountID)
- })
- return
- }
- b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load()))
- }()
+ _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID)
}
// UpdateAccountPeer updates a single peer that belongs to an account.
// Should be called when changes need to be synced to a specific peer only.
func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) {
- if !am.peersUpdateManager.HasChannel(peerId) {
- log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId)
- return
- }
-
- account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err)
- return
- }
-
- peer := account.GetPeer(peerId)
- if peer == nil {
- log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId)
- return
- }
-
- approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err)
- return
- }
-
- dnsCache := &DNSConfigCache{}
- dnsDomain := am.GetDNSDomain(account.Settings)
- customZone := account.GetPeersCustomZone(ctx, dnsDomain)
- resourcePolicies := account.GetResourcePoliciesMap()
- routers := account.GetResourceRoutersMap()
-
- postureChecks, err := am.getPeerPostureChecks(account, peerId)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
- return
- }
-
- proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
- return
- }
-
- remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
-
- proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
- if ok {
- remotePeerNetworkMap.Merge(proxyNetworkMap)
- }
-
- extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
- return
- }
-
- peerGroups := account.GetPeerGroups(peerId)
- dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
-
- update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
- am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
+ _ = am.networkMapController.UpdateAccountPeer(ctx, accountId, peerId)
}
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
@@ -1527,16 +1227,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
}
-// IsPeerInActiveGroup checks if the given peer is part of a group that is used
-// in an active DNS, route, or ACL configuration.
-func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
- peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
- if err != nil {
- return false, err
- }
- return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
-}
-
// deletePeers deletes all specified peers and sends updates to the remote peers.
// Returns a slice of functions to save events after successful peer deletion.
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
@@ -1546,14 +1236,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
if err != nil {
return nil, err
}
- dnsDomain := am.GetDNSDomain(settings)
-
- network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
- if err != nil {
- return nil, err
- }
-
- dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion)
+ dnsDomain := am.networkMapController.GetDNSDomain(settings)
for _, peer := range peers {
if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
@@ -1587,25 +1270,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
return nil, err
}
-
- am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{
- Update: &proto.SyncResponse{
- RemotePeers: []*proto.RemotePeerConfig{},
- RemotePeersIsEmpty: true,
- NetworkMap: &proto.NetworkMap{
- Serial: network.CurrentSerial(),
- RemotePeers: []*proto.RemotePeerConfig{},
- RemotePeersIsEmpty: true,
- FirewallRules: []*proto.FirewallRule{},
- FirewallRulesIsEmpty: true,
- DNSConfig: &proto.DNSConfig{
- ForwarderPort: dnsFwdPort,
- },
- },
- },
- NetworkMap: &types.NetworkMap{},
- })
- am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
@@ -1614,14 +1278,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return peerDeletedEvents, nil
}
-func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
- labelMap := make(map[string]struct{}, len(existingLabels))
- for _, label := range existingLabels {
- labelMap[label] = struct{}{}
- }
- return labelMap
-}
-
// validatePeerDelete checks if the peer can be deleted.
func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error {
linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId)
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index 42b3244ae..752563299 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -13,7 +13,6 @@ import (
"strconv"
"strings"
"sync"
- "sync/atomic"
"testing"
"time"
@@ -25,10 +24,16 @@ import (
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
+ "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
- "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status"
@@ -168,7 +173,16 @@ func TestPeer_SessionExpired(t *testing.T) {
}
func TestAccountManager_GetNetworkMap(t *testing.T) {
- manager, err := createManager(t)
+ testGetNetworkMapGeneral(t)
+}
+
+func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
+ t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
+ testGetNetworkMapGeneral(t)
+}
+
+func testGetNetworkMapGeneral(t *testing.T) {
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -240,7 +254,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
// TODO: disable until we start use policy again
t.Skip()
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -417,7 +431,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
func TestAccountManager_GetPeerNetwork(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -478,7 +492,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
}
func TestDefaultAccountManager_GetPeer(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -665,7 +679,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -733,12 +747,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
}
-func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) {
+func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, *update_channel.PeersUpdateManager, string, string, error) {
b.Helper()
- manager, err := createManager(b)
+ manager, updateManager, err := createManager(b)
if err != nil {
- return nil, "", "", err
+ return nil, nil, "", "", err
}
accountID := "test_account"
@@ -789,7 +803,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
ips := account.GetTakenIPs()
peerIP, err := types.AllocatePeerIP(account.Network.Net, ips)
if err != nil {
- return nil, "", "", err
+ return nil, nil, "", "", err
}
peerKey, _ := wgtypes.GeneratePrivateKey()
@@ -895,10 +909,10 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
- return nil, "", "", err
+ return nil, nil, "", "", err
}
- return manager, accountID, regularUser, nil
+ return manager, updateManager, accountID, regularUser, nil
}
func BenchmarkGetPeers(b *testing.B) {
@@ -919,7 +933,7 @@ func BenchmarkGetPeers(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
- manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
+ manager, _, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -959,7 +973,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
- manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
+ manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -971,14 +985,10 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
b.Fatalf("Failed to get account: %v", err)
}
- peerChannels := make(map[string]chan *UpdateMessage)
-
for peerID := range account.Peers {
- peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
+ updateManager.CreateChannel(ctx, peerID)
}
- manager.peersUpdateManager.peerChannels = peerChannels
-
b.ResetTimer()
start := time.Now()
@@ -1003,7 +1013,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
}
+func TestUpdateAccountPeers_Experimental(t *testing.T) {
+ t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
+ testUpdateAccountPeers(t)
+}
+
func TestUpdateAccountPeers(t *testing.T) {
+ testUpdateAccountPeers(t)
+}
+
+func testUpdateAccountPeers(t *testing.T) {
testCases := []struct {
name string
peers int
@@ -1019,7 +1038,7 @@ func TestUpdateAccountPeers(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
+ manager, updateManager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
if err != nil {
t.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -1031,20 +1050,19 @@ func TestUpdateAccountPeers(t *testing.T) {
t.Fatalf("Failed to get account: %v", err)
}
- peerChannels := make(map[string]chan *UpdateMessage)
+ peerChannels := make(map[string]chan *network_map.UpdateMessage)
for peerID := range account.Peers {
- peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
+ peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID)
}
- manager.peersUpdateManager.peerChannels = peerChannels
manager.UpdateAccountPeers(ctx, account.Id)
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
- assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
- assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
+ assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
+ assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
}
})
}
@@ -1079,7 +1097,7 @@ func TestToSyncResponse(t *testing.T) {
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
- turnRelayToken := &Token{
+ turnRelayToken := &grpc.Token{
Payload: "turn-user",
Signature: "turn-pass",
}
@@ -1159,9 +1177,9 @@ func TestToSyncResponse(t *testing.T) {
},
},
}
- dnsCache := &DNSConfigCache{}
+ dnsCache := &cache.DNSConfigCache{}
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
- response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort)
+ response := grpc.ToSyncResponse(context.Background(), config, config.HttpConfig, config.DeviceAuthorizationFlow, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
assert.NotNil(t, response)
// assert peer config
@@ -1212,6 +1230,7 @@ func TestToSyncResponse(t *testing.T) {
assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID)
// assert network map DNSConfig
assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable)
+ //nolint
assert.Equal(t, int64(dnsForwarderPort), response.NetworkMap.DNSConfig.ForwarderPort)
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones))
assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups))
@@ -1271,7 +1290,12 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
- am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
+ ctx := context.Background()
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := NewAccountRequestBuffer(ctx, s)
+ networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
+
+ am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1351,7 +1375,12 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
AnyTimes()
permissionsManager := permissions.NewManager(s)
- am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
+ ctx := context.Background()
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := NewAccountRequestBuffer(ctx, s)
+ networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
+
+ am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1499,7 +1528,12 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s)
- am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
+ ctx := context.Background()
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := NewAccountRequestBuffer(ctx, s)
+ networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
+
+ am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1548,6 +1582,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
}
func Test_LoginPeer(t *testing.T) {
+ t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
@@ -1573,7 +1608,12 @@ func Test_LoginPeer(t *testing.T) {
AnyTimes()
permissionsManager := permissions.NewManager(s)
- am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
+ ctx := context.Background()
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := NewAccountRequestBuffer(ctx, s)
+ networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
+
+ am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1706,7 +1746,7 @@ func Test_LoginPeer(t *testing.T) {
}
func TestPeerAccountPeersUpdate(t *testing.T) {
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+ manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
require.NoError(t, err)
@@ -1763,13 +1803,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
var peer5 *nbpeer.Peer
var peer6 *nbpeer.Peer
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
+ t.Skip("Currently all updates will trigger a network map")
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
@@ -1790,7 +1831,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
- peerShouldNotReceiveUpdate(t, updMsg) //
+ peerShouldReceiveUpdate(t, updMsg) //
close(done)
}()
@@ -1815,7 +1856,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
- peerShouldNotReceiveUpdate(t, updMsg)
+ peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1871,6 +1912,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
})
t.Run("validator requires no update", func(t *testing.T) {
+ t.Skip("Currently all updates will trigger a network map")
+
requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
@@ -2072,7 +2115,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
}
func Test_DeletePeer(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -2169,7 +2212,7 @@ func Test_IsUniqueConstraintError(t *testing.T) {
}
func Test_AddPeer(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -2257,136 +2300,8 @@ func Test_AddPeer(t *testing.T) {
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
}
-func TestBufferUpdateAccountPeers(t *testing.T) {
- const (
- peersCount = 1000
- updateAccountInterval = 50 * time.Millisecond
- )
-
- var (
- deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
- uapLastRun, dpLastRun atomic.Int64
-
- totalNewRuns, totalOldRuns int
- )
-
- uap := func(ctx context.Context, accountID string) {
- updatePeersDeleted.Store(deletedPeers.Load())
- updatePeersRuns.Add(1)
- uapLastRun.Store(time.Now().UnixMilli())
- time.Sleep(100 * time.Millisecond)
- }
-
- t.Run("new approach", func(t *testing.T) {
- updatePeersRuns.Store(0)
- updatePeersDeleted.Store(0)
- deletedPeers.Store(0)
-
- var mustore sync.Map
- bufupd := func(ctx context.Context, accountID string) {
- mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
- b := mu.(*bufferUpdate)
-
- if !b.mu.TryLock() {
- b.update.Store(true)
- return
- }
-
- if b.next != nil {
- b.next.Stop()
- }
-
- go func() {
- defer b.mu.Unlock()
- uap(ctx, accountID)
- if !b.update.Load() {
- return
- }
- b.update.Store(false)
- b.next = time.AfterFunc(updateAccountInterval, func() {
- uap(ctx, accountID)
- })
- }()
- }
- dp := func(ctx context.Context, accountID, peerID, userID string) error {
- deletedPeers.Add(1)
- dpLastRun.Store(time.Now().UnixMilli())
- time.Sleep(10 * time.Millisecond)
- bufupd(ctx, accountID)
- return nil
- }
-
- am := mock_server.MockAccountManager{
- UpdateAccountPeersFunc: uap,
- BufferUpdateAccountPeersFunc: bufupd,
- DeletePeerFunc: dp,
- }
- empty := ""
- for range peersCount {
- //nolint
- am.DeletePeer(context.Background(), empty, empty, empty)
- }
- time.Sleep(100 * time.Millisecond)
-
- assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
- assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
- assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
-
- totalNewRuns = int(updatePeersRuns.Load())
- })
-
- t.Run("old approach", func(t *testing.T) {
- updatePeersRuns.Store(0)
- updatePeersDeleted.Store(0)
- deletedPeers.Store(0)
-
- var mustore sync.Map
- bufupd := func(ctx context.Context, accountID string) {
- mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
- b := mu.(*sync.Mutex)
-
- if !b.TryLock() {
- return
- }
-
- go func() {
- time.Sleep(updateAccountInterval)
- b.Unlock()
- uap(ctx, accountID)
- }()
- }
- dp := func(ctx context.Context, accountID, peerID, userID string) error {
- deletedPeers.Add(1)
- dpLastRun.Store(time.Now().UnixMilli())
- time.Sleep(10 * time.Millisecond)
- bufupd(ctx, accountID)
- return nil
- }
-
- am := mock_server.MockAccountManager{
- UpdateAccountPeersFunc: uap,
- BufferUpdateAccountPeersFunc: bufupd,
- DeletePeerFunc: dp,
- }
- empty := ""
- for range peersCount {
- //nolint
- am.DeletePeer(context.Background(), empty, empty, empty)
- }
- time.Sleep(100 * time.Millisecond)
-
- assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
- assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
- assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
-
- totalOldRuns = int(updatePeersRuns.Load())
- })
- assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
- t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
-}
-
func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2423,7 +2338,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
}
func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2457,7 +2372,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
}
func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2522,7 +2437,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
}
func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go
deleted file mode 100644
index cb135f4ac..000000000
--- a/management/server/peers/manager.go
+++ /dev/null
@@ -1,68 +0,0 @@
-package peers
-
-//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
-
-import (
- "context"
- "fmt"
-
- "github.com/netbirdio/netbird/management/server/peer"
- "github.com/netbirdio/netbird/management/server/permissions"
- "github.com/netbirdio/netbird/management/server/permissions/modules"
- "github.com/netbirdio/netbird/management/server/permissions/operations"
- "github.com/netbirdio/netbird/management/server/store"
- "github.com/netbirdio/netbird/shared/management/status"
-)
-
-type Manager interface {
- GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
- GetPeerAccountID(ctx context.Context, peerID string) (string, error)
- GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
- GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
-}
-
-type managerImpl struct {
- store store.Store
- permissionsManager permissions.Manager
-}
-
-func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
- return &managerImpl{
- store: store,
- permissionsManager: permissionsManager,
- }
-}
-
-func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
- allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
- if err != nil {
- return nil, fmt.Errorf("failed to validate user permissions: %w", err)
- }
-
- if !allowed {
- return nil, status.NewPermissionDeniedError()
- }
-
- return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
-}
-
-func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
- allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
- if err != nil {
- return nil, fmt.Errorf("failed to validate user permissions: %w", err)
- }
-
- if !allowed {
- return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
- }
-
- return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
-}
-
-func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
- return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
-}
-
-func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
- return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
-}
diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go
index 891fa59bb..e6bdd2025 100644
--- a/management/server/permissions/manager.go
+++ b/management/server/permissions/manager.go
@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"
+ "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -22,6 +23,7 @@ type Manager interface {
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error)
+ SetAccountManager(accountManager account.Manager)
}
type managerImpl struct {
@@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR
return permissions, nil
}
+
+func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
+ // no-op
+}
diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go
index fa115d628..ec9f263f9 100644
--- a/management/server/permissions/manager_mock.go
+++ b/management/server/permissions/manager_mock.go
@@ -9,6 +9,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
+ account "github.com/netbirdio/netbird/management/server/account"
modules "github.com/netbirdio/netbird/management/server/permissions/modules"
operations "github.com/netbirdio/netbird/management/server/permissions/operations"
roles "github.com/netbirdio/netbird/management/server/permissions/roles"
@@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role)
}
+// SetAccountManager mocks base method.
+func (m *MockManager) SetAccountManager(accountManager account.Manager) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetAccountManager", accountManager)
+}
+
+// SetAccountManager indicates an expected call of SetAccountManager.
+func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
+}
+
// ValidateAccountAccess mocks base method.
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
m.ctrl.T.Helper()
diff --git a/management/server/policy.go b/management/server/policy.go
index 9e4b3f73a..3e84c3d10 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -10,7 +10,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
- "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/posture"
@@ -252,31 +251,3 @@ func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []strin
return validIDs
}
-
-// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
-func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
- result := make([]*proto.FirewallRule, len(rules))
- for i := range rules {
- rule := rules[i]
-
- fwRule := &proto.FirewallRule{
- PolicyID: []byte(rule.PolicyID),
- PeerIP: rule.PeerIP,
- Direction: getProtoDirection(rule.Direction),
- Action: getProtoAction(rule.Action),
- Protocol: getProtoProtocol(rule.Protocol),
- Port: rule.Port,
- }
-
- if shouldUsePortRange(fwRule) {
- fwRule.PortInfo = rule.PortRange.ToProto()
- }
-
- result[i] = fwRule
- }
- return result
-}
-
-func shouldUsePortRange(rule *proto.FirewallRule) bool {
- return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
-}
diff --git a/management/server/policy_test.go b/management/server/policy_test.go
index 4a08f4c33..a3f987732 100644
--- a/management/server/policy_test.go
+++ b/management/server/policy_test.go
@@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers {
- peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
+ peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
}
})
t.Run("check first peer map details", func(t *testing.T) {
- peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
+ peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 8)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
@@ -266,7 +266,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
expectedFirewallRules := []*types.FirewallRule{
{
- PeerIP: "0.0.0.0",
+ PeerIP: "100.65.14.88",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
@@ -274,7 +274,103 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
PolicyID: "RuleDefault",
},
{
- PeerIP: "0.0.0.0",
+ PeerIP: "100.65.14.88",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.62.5",
+ Direction: types.FirewallRuleDirectionIN,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.62.5",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.254.139",
+ Direction: types.FirewallRuleDirectionIN,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.254.139",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.32.206",
+ Direction: types.FirewallRuleDirectionIN,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.32.206",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.250.202",
+ Direction: types.FirewallRuleDirectionIN,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.250.202",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.13.186",
+ Direction: types.FirewallRuleDirectionIN,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.13.186",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.29.55",
+ Direction: types.FirewallRuleDirectionIN,
+ Action: "accept",
+ Protocol: "all",
+ Port: "",
+ PolicyID: "RuleDefault",
+ },
+ {
+ PeerIP: "100.65.29.55",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
@@ -413,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
})
t.Run("check port ranges support for older peers", func(t *testing.T) {
- peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
+ peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 1)
assert.Contains(t, peers, account.Peers["peerI"])
@@ -539,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}
t.Run("check first peer map", func(t *testing.T) {
- peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
+ peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*types.FirewallRule{
@@ -569,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map", func(t *testing.T) {
- peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
+ peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerB"])
expectedFirewallRules := []*types.FirewallRule{
@@ -601,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) {
- peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
+ peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*types.FirewallRule{
@@ -623,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map directional only", func(t *testing.T) {
- peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
+ peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerB"])
expectedFirewallRules := []*types.FirewallRule{
@@ -821,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check.
- peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
+ peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -831,12 +927,60 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
- peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
+ peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
- assert.Len(t, firewallRules, 1)
+ assert.Len(t, firewallRules, 7)
expectedFirewallRules := []*types.FirewallRule{
{
- PeerIP: "0.0.0.0",
+ PeerIP: "100.65.80.39",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "tcp",
+ Port: "80",
+ PolicyID: "RuleSwarm",
+ },
+ {
+ PeerIP: "100.65.14.88",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "tcp",
+ Port: "80",
+ PolicyID: "RuleSwarm",
+ },
+ {
+ PeerIP: "100.65.62.5",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "tcp",
+ Port: "80",
+ PolicyID: "RuleSwarm",
+ },
+ {
+ PeerIP: "100.65.32.206",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "tcp",
+ Port: "80",
+ PolicyID: "RuleSwarm",
+ },
+ {
+ PeerIP: "100.65.13.186",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "tcp",
+ Port: "80",
+ PolicyID: "RuleSwarm",
+ },
+ {
+ PeerIP: "100.65.29.55",
+ Direction: types.FirewallRuleDirectionOUT,
+ Action: "accept",
+ Protocol: "tcp",
+ Port: "80",
+ PolicyID: "RuleSwarm",
+ },
+ {
+ PeerIP: "100.65.21.56",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
@@ -848,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
- peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
+ peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -858,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
- peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
+ peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -873,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
- peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
+ peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
- peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
+ peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
- peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
+ peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -900,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
- peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
+ peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"])
- peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
+ peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 5)
// assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"])
@@ -991,7 +1135,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
}
func TestPolicyAccountPeersUpdate(t *testing.T) {
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+ manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{
{
@@ -1020,9 +1164,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err)
}
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ updateManager.CloseChannel(context.Background(), peer1.ID)
})
var policyWithGroupRulesNoPeers *types.Policy
diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go
index d65dc5045..f0bbbc32e 100644
--- a/management/server/posture/checks.go
+++ b/management/server/posture/checks.go
@@ -7,8 +7,8 @@ import (
"regexp"
"github.com/hashicorp/go-version"
- "github.com/netbirdio/netbird/shared/management/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
diff --git a/management/server/posture/os_version.go b/management/server/posture/os_version.go
index 411f4c2c6..2ef97a066 100644
--- a/management/server/posture/os_version.go
+++ b/management/server/posture/os_version.go
@@ -82,7 +82,7 @@ func (c *OSVersionCheck) Validate() error {
func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
if check == nil {
- log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
+ log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil
}
@@ -107,7 +107,7 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
if check == nil {
- log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
+ log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil
}
diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go
index 943f2a970..9a743eb8c 100644
--- a/management/server/posture_checks.go
+++ b/management/server/posture_checks.go
@@ -2,19 +2,15 @@ package server
import (
"context"
- "errors"
- "fmt"
"slices"
"github.com/rs/xid"
- "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
- "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -136,27 +132,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
}
-// getPeerPostureChecks returns the posture checks applied for a given peer.
-func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
- peerPostureChecks := make(map[string]*posture.Checks)
-
- if len(account.PostureChecks) == 0 {
- return nil, nil
- }
-
- for _, policy := range account.Policies {
- if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
- continue
- }
-
- if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
- return nil, err
- }
- }
-
- return maps.Values(peerPostureChecks), nil
-}
-
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
@@ -183,7 +158,7 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.St
// validatePostureChecks validates the posture checks.
func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error {
if err := postureChecks.Validate(); err != nil {
- return status.Errorf(status.InvalidArgument, err.Error()) //nolint
+ return status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint
}
// If the posture check already has an ID, verify its existence in the store.
@@ -211,50 +186,6 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
return nil
}
-// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
-func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
- isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
- if err != nil {
- return err
- }
-
- if !isInGroup {
- return nil
- }
-
- for _, sourcePostureCheckID := range policy.SourcePostureChecks {
- postureCheck := account.GetPostureChecks(sourcePostureCheckID)
- if postureCheck == nil {
- return errors.New("failed to add policy posture checks: posture checks not found")
- }
- peerPostureChecks[sourcePostureCheckID] = postureCheck
- }
-
- return nil
-}
-
-// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
-func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
- for _, rule := range policy.Rules {
- if !rule.Enabled {
- continue
- }
-
- for _, sourceGroup := range rule.Sources {
- group := account.GetGroup(sourceGroup)
- if group == nil {
- return false, fmt.Errorf("failed to check peer in policy source group: group not found")
- }
-
- if slices.Contains(group.Peers, peerID) {
- return true, nil
- }
- }
- }
-
- return false, nil
-}
-
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go
index 67760d55a..13152ed12 100644
--- a/management/server/posture_checks_test.go
+++ b/management/server/posture_checks_test.go
@@ -21,7 +21,7 @@ const (
)
func TestDefaultAccountManager_PostureCheck(t *testing.T) {
- am, err := createManager(t)
+ am, _, err := createManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -123,7 +123,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
}
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+ manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{
{
@@ -147,9 +147,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err)
}
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ updateManager.CloseChannel(context.Background(), peer1.ID)
})
postureCheckA := &posture.Checks{
@@ -359,9 +359,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked posture check to policy where destination has peers but source does not
// should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
- updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
+ updMsg1 := updateManager.CreateChannel(context.Background(), peer2.ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
+ updateManager.CloseChannel(context.Background(), peer2.ID)
})
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
@@ -445,7 +445,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
func TestArePostureCheckChangesAffectPeers(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
require.NoError(t, err, "failed to create account manager")
account, err := initTestPostureChecksAccount(manager)
diff --git a/management/server/route.go b/management/server/route.go
index 4510426bb..2b4f11d05 100644
--- a/management/server/route.go
+++ b/management/server/route.go
@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
- "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -372,103 +371,12 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID
return groupsMap, nil
}
-func toProtocolRoute(route *route.Route) *proto.Route {
- return &proto.Route{
- ID: string(route.ID),
- NetID: string(route.NetID),
- Network: route.Network.String(),
- Domains: route.Domains.ToPunycodeList(),
- NetworkType: int64(route.NetworkType),
- Peer: route.Peer,
- Metric: int64(route.Metric),
- Masquerade: route.Masquerade,
- KeepRoute: route.KeepRoute,
- SkipAutoApply: route.SkipAutoApply,
- }
-}
-
-func toProtocolRoutes(routes []*route.Route) []*proto.Route {
- protoRoutes := make([]*proto.Route, 0, len(routes))
- for _, r := range routes {
- protoRoutes = append(protoRoutes, toProtocolRoute(r))
- }
- return protoRoutes
-}
-
// getPlaceholderIP returns a placeholder IP address for the route if domains are used
func getPlaceholderIP() netip.Prefix {
// Using an IP from the documentation range to minimize impact in case older clients try to set a route
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
-func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
- result := make([]*proto.RouteFirewallRule, len(rules))
- for i := range rules {
- rule := rules[i]
- result[i] = &proto.RouteFirewallRule{
- SourceRanges: rule.SourceRanges,
- Action: getProtoAction(rule.Action),
- Destination: rule.Destination,
- Protocol: getProtoProtocol(rule.Protocol),
- PortInfo: getProtoPortInfo(rule),
- IsDynamic: rule.IsDynamic,
- Domains: rule.Domains.ToPunycodeList(),
- PolicyID: []byte(rule.PolicyID),
- RouteID: string(rule.RouteID),
- }
- }
-
- return result
-}
-
-// getProtoDirection converts the direction to proto.RuleDirection.
-func getProtoDirection(direction int) proto.RuleDirection {
- if direction == types.FirewallRuleDirectionOUT {
- return proto.RuleDirection_OUT
- }
- return proto.RuleDirection_IN
-}
-
-// getProtoAction converts the action to proto.RuleAction.
-func getProtoAction(action string) proto.RuleAction {
- if action == string(types.PolicyTrafficActionDrop) {
- return proto.RuleAction_DROP
- }
- return proto.RuleAction_ACCEPT
-}
-
-// getProtoProtocol converts the protocol to proto.RuleProtocol.
-func getProtoProtocol(protocol string) proto.RuleProtocol {
- switch types.PolicyRuleProtocolType(protocol) {
- case types.PolicyRuleProtocolALL:
- return proto.RuleProtocol_ALL
- case types.PolicyRuleProtocolTCP:
- return proto.RuleProtocol_TCP
- case types.PolicyRuleProtocolUDP:
- return proto.RuleProtocol_UDP
- case types.PolicyRuleProtocolICMP:
- return proto.RuleProtocol_ICMP
- default:
- return proto.RuleProtocol_UNKNOWN
- }
-}
-
-// getProtoPortInfo converts the port info to proto.PortInfo.
-func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
- var portInfo proto.PortInfo
- if rule.Port != 0 {
- portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
- } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
- portInfo.PortSelection = &proto.PortInfo_Range_{
- Range: &proto.PortInfo_Range{
- Start: uint32(portRange.Start),
- End: uint32(portRange.End),
- },
- }
- }
- return &portInfo
-}
-
// areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers.
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
diff --git a/management/server/route_test.go b/management/server/route_test.go
index 388db140c..a413d545b 100644
--- a/management/server/route_test.go
+++ b/management/server/route_test.go
@@ -14,6 +14,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
+ "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -432,7 +437,7 @@ func TestCreateRoute(t *testing.T) {
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
- am, err := createRouterManager(t)
+ am, _, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -922,7 +927,7 @@ func TestSaveRoute(t *testing.T) {
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
- am, err := createRouterManager(t)
+ am, _, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -1024,7 +1029,7 @@ func TestDeleteRoute(t *testing.T) {
Enabled: true,
}
- am, err := createRouterManager(t)
+ am, _, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -1071,7 +1076,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
AccessControlGroups: []string{routeGroup1},
}
- am, err := createRouterManager(t)
+ am, _, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -1163,7 +1168,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
AccessControlGroups: []string{routeGroup1},
}
- am, err := createRouterManager(t)
+ am, _, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -1250,11 +1255,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1")
}
-func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
+func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
t.Helper()
store, err := createRouterStore(t)
if err != nil {
- return nil, err
+ return nil, nil, err
}
eventStore := &activity.InMemoryEventStore{}
@@ -1285,7 +1290,16 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store)
- return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
+ ctx := context.Background()
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := NewAccountRequestBuffer(ctx, store)
+ networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
+
+ am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
+ if err != nil {
+ return nil, nil, err
+ }
+ return am, updateManager, nil
}
func createRouterStore(t *testing.T) (store.Store, error) {
@@ -1948,7 +1962,7 @@ func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFi
}
func TestRouteAccountPeersUpdate(t *testing.T) {
- manager, err := createRouterManager(t)
+ manager, updateManager, err := createRouterManager(t)
require.NoError(t, err, "failed to create account manager")
account, err := initTestRouteAccount(t, manager)
@@ -1976,9 +1990,9 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
require.NoError(t, err, "failed to create group %s", group.Name)
}
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
+ updateManager.CloseChannel(context.Background(), peer1ID)
})
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update
diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go
index e55b33c94..bc361bbd7 100644
--- a/management/server/setupkey_test.go
+++ b/management/server/setupkey_test.go
@@ -18,7 +18,7 @@ import (
)
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -93,7 +93,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
}
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -198,7 +198,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
}
func TestGetSetupKeys(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -396,7 +396,7 @@ func TestSetupKey_Copy(t *testing.T) {
}
func TestSetupKeyAccountPeersUpdate(t *testing.T) {
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+ manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
@@ -420,9 +420,9 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err)
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ updateManager.CloseChannel(context.Background(), peer1.ID)
})
var setupKey *types.SetupKey
@@ -465,7 +465,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
}
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go
index 382d026c8..d2220d4b4 100644
--- a/management/server/store/sql_store.go
+++ b/management/server/store/sql_store.go
@@ -2,6 +2,7 @@ package store
import (
"context"
+ "database/sql"
"encoding/json"
"errors"
"fmt"
@@ -15,6 +16,8 @@ import (
"sync"
"time"
+ "github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgxpool"
log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
@@ -24,7 +27,6 @@ import (
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -46,6 +48,11 @@ const (
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
+
+ pgMaxConnections = 30
+ pgMinConnections = 1
+ pgMaxConnLifetime = 60 * time.Minute
+ pgHealthCheckPeriod = 1 * time.Minute
)
// SqlStore represents an account storage backed by a Sql DB persisted to disk
@@ -55,6 +62,7 @@ type SqlStore struct {
metrics telemetry.AppMetrics
installationPK int
storeEngine types.Engine
+ pool *pgxpool.Pool
}
type installation struct {
@@ -76,12 +84,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
conns = runtime.NumCPU()
}
- switch storeEngine {
- case types.MysqlStoreEngine:
- if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil {
- return nil, err
- }
- case types.SqliteStoreEngine:
+ if storeEngine == types.SqliteStoreEngine {
if err == nil {
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
}
@@ -89,8 +92,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
}
sql.SetMaxOpenConns(conns)
+ sql.SetMaxIdleConns(conns)
+ sql.SetConnMaxLifetime(time.Hour)
+ sql.SetConnMaxIdleTime(3 * time.Minute)
- log.WithContext(ctx).Infof("Set max open db connections to %d", conns)
+ log.WithContext(ctx).Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v",
+ conns, conns, time.Hour, 3*time.Minute)
if skipMigration {
log.WithContext(ctx).Infof("skipping migration")
@@ -162,7 +169,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
group.StoreGroupPeers()
}
- err := s.db.Transaction(func(tx *gorm.DB) error {
+ err := s.transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
@@ -257,7 +264,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID,
func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error {
start := time.Now()
- err := s.db.Transaction(func(tx *gorm.DB) error {
+ err := s.transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
@@ -280,7 +287,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
- log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds())
+ log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds())
return err
}
@@ -307,7 +314,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
peerCopy := peer.Copy()
peerCopy.AccountID = accountID
- err := s.db.Transaction(func(tx *gorm.DB) error {
+ err := s.transaction(func(tx *gorm.DB) error {
// check if peer exists before saving
var peerID string
result := tx.Model(&nbpeer.Peer{}).Select("id").Take(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
@@ -405,6 +412,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW
return nil
}
+// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
+func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
+ result := s.db.Model(&nbpeer.Peer{}).
+ Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true).
+ Update("peer_status_requires_approval", false)
+ if result.Error != nil {
+ return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error)
+ }
+
+ return int(result.RowsAffected), nil
+}
+
// SaveUsers saves the given list of users to the database.
func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
if len(users) == 0 {
@@ -575,16 +594,13 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
- ctx, cancel := getDebuggingCtx(ctx)
- defer cancel()
-
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User
- result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
+ result := tx.Take(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -596,7 +612,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
}
func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) error {
- err := s.db.Transaction(func(tx *gorm.DB) error {
+ err := s.transaction(func(tx *gorm.DB) error {
result := tx.Delete(&types.PersonalAccessToken{}, "user_id = ?", userID)
if result.Error != nil {
return result.Error
@@ -774,6 +790,13 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.
}
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
+ if s.pool != nil {
+ return s.getAccountPgx(ctx, accountID)
+ }
+ return s.getAccountGorm(ctx, accountID)
+}
+
+func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
@@ -784,9 +807,19 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
var account types.Account
result := s.db.Model(&account).
- Omit("GroupsG").
- Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
- Preload(clause.Associations).
+ Preload("UsersG.PATsG"). // have to be specified as this is nested reference
+ Preload("Policies.Rules").
+ Preload("SetupKeysG").
+ Preload("PeersG").
+ Preload("UsersG").
+ Preload("GroupsG.GroupPeers").
+ Preload("RoutesG").
+ Preload("NameServerGroupsG").
+ Preload("PostureChecks").
+ Preload("Networks").
+ Preload("NetworkRouters").
+ Preload("NetworkResources").
+ Preload("Onboarding").
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
@@ -796,70 +829,1154 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
return nil, status.NewGetAccountFromStoreError(result.Error)
}
- // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
- for i, policy := range account.Policies {
- var rules []*types.PolicyRule
- err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
- if err != nil {
- return nil, status.Errorf(status.NotFound, "rule not found")
- }
- account.Policies[i].Rules = rules
- }
-
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
- account.SetupKeys[key.Key] = key.Copy()
+ if key.UpdatedAt.IsZero() {
+ key.UpdatedAt = key.CreatedAt
+ }
+ if key.AutoGroups == nil {
+ key.AutoGroups = []string{}
+ }
+ account.SetupKeys[key.Key] = &key
}
account.SetupKeysG = nil
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG {
- account.Peers[peer.ID] = peer.Copy()
+ account.Peers[peer.ID] = &peer
}
account.PeersG = nil
-
account.Users = make(map[string]*types.User, len(account.UsersG))
for _, user := range account.UsersG {
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
for _, pat := range user.PATsG {
- user.PATs[pat.ID] = pat.Copy()
+ pat.UserID = ""
+ user.PATs[pat.ID] = &pat
}
- account.Users[user.Id] = user.Copy()
+ if user.AutoGroups == nil {
+ user.AutoGroups = []string{}
+ }
+ account.Users[user.Id] = &user
+ user.PATsG = nil
}
account.UsersG = nil
-
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
- account.Groups[group.ID] = group.Copy()
+ group.Peers = make([]string, len(group.GroupPeers))
+ for i, gp := range group.GroupPeers {
+ group.Peers[i] = gp.PeerID
+ }
+ if group.Resources == nil {
+ group.Resources = []types.Resource{}
+ }
+ account.Groups[group.ID] = group
}
account.GroupsG = nil
- var groupPeers []types.GroupPeer
- s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
- Find(&groupPeers)
- for _, groupPeer := range groupPeers {
- if group, ok := account.Groups[groupPeer.GroupID]; ok {
- group.Peers = append(group.Peers, groupPeer.PeerID)
- } else {
- log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
+ account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
+ for _, route := range account.RoutesG {
+ account.Routes[route.ID] = &route
+ }
+ account.RoutesG = nil
+ account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
+ for _, ns := range account.NameServerGroupsG {
+ ns.AccountID = ""
+ if ns.NameServers == nil {
+ ns.NameServers = []nbdns.NameServer{}
+ }
+ if ns.Groups == nil {
+ ns.Groups = []string{}
+ }
+ if ns.Domains == nil {
+ ns.Domains = []string{}
+ }
+ account.NameServerGroups[ns.ID] = &ns
+ }
+ account.NameServerGroupsG = nil
+ account.InitOnce()
+ return &account, nil
+}
+
+func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) {
+ account, err := s.getAccount(ctx, accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ var wg sync.WaitGroup
+ errChan := make(chan error, 12)
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ keys, err := s.getSetupKeys(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.SetupKeysG = keys
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ peers, err := s.getPeers(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.PeersG = peers
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ users, err := s.getUsers(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.UsersG = users
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ groups, err := s.getGroups(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.GroupsG = groups
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ policies, err := s.getPolicies(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.Policies = policies
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ routes, err := s.getRoutes(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.RoutesG = routes
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ nsgs, err := s.getNameServerGroups(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.NameServerGroupsG = nsgs
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ checks, err := s.getPostureChecks(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.PostureChecks = checks
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ networks, err := s.getNetworks(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.Networks = networks
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ routers, err := s.getNetworkRouters(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.NetworkRouters = routers
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ resources, err := s.getNetworkResources(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.NetworkResources = resources
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ err := s.getAccountOnboarding(ctx, accountID, account)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ }()
+
+ wg.Wait()
+ close(errChan)
+ for e := range errChan {
+ if e != nil {
+ return nil, e
}
}
- account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
- for _, route := range account.RoutesG {
- account.Routes[route.ID] = route.Copy()
+ var userIDs []string
+ for _, u := range account.UsersG {
+ userIDs = append(userIDs, u.Id)
+ }
+ var policyIDs []string
+ for _, p := range account.Policies {
+ policyIDs = append(policyIDs, p.ID)
+ }
+ var groupIDs []string
+ for _, g := range account.GroupsG {
+ groupIDs = append(groupIDs, g.ID)
+ }
+
+ wg.Add(3)
+ errChan = make(chan error, 3)
+
+ var pats []types.PersonalAccessToken
+ go func() {
+ defer wg.Done()
+ var err error
+ pats, err = s.getPersonalAccessTokens(ctx, userIDs)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ var rules []*types.PolicyRule
+ go func() {
+ defer wg.Done()
+ var err error
+ rules, err = s.getPolicyRules(ctx, policyIDs)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ var groupPeers []types.GroupPeer
+ go func() {
+ defer wg.Done()
+ var err error
+ groupPeers, err = s.getGroupPeers(ctx, groupIDs)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ wg.Wait()
+ close(errChan)
+ for e := range errChan {
+ if e != nil {
+ return nil, e
+ }
+ }
+
+ patsByUserID := make(map[string][]*types.PersonalAccessToken)
+ for i := range pats {
+ pat := &pats[i]
+ patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat)
+ pat.UserID = ""
+ }
+
+ rulesByPolicyID := make(map[string][]*types.PolicyRule)
+ for _, rule := range rules {
+ rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
+ }
+
+ peersByGroupID := make(map[string][]string)
+ for _, gp := range groupPeers {
+ peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
+ }
+
+ account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
+ for i := range account.SetupKeysG {
+ key := &account.SetupKeysG[i]
+ account.SetupKeys[key.Key] = key
+ }
+
+ account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
+ for i := range account.PeersG {
+ peer := &account.PeersG[i]
+ account.Peers[peer.ID] = peer
+ }
+
+ account.Users = make(map[string]*types.User, len(account.UsersG))
+ for i := range account.UsersG {
+ user := &account.UsersG[i]
+ user.PATs = make(map[string]*types.PersonalAccessToken)
+ if userPats, ok := patsByUserID[user.Id]; ok {
+ for j := range userPats {
+ pat := userPats[j]
+ user.PATs[pat.ID] = pat
+ }
+ }
+ account.Users[user.Id] = user
+ }
+
+ for i := range account.Policies {
+ policy := account.Policies[i]
+ if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
+ policy.Rules = policyRules
+ }
+ }
+
+ account.Groups = make(map[string]*types.Group, len(account.GroupsG))
+ for i := range account.GroupsG {
+ group := account.GroupsG[i]
+ if peerIDs, ok := peersByGroupID[group.ID]; ok {
+ group.Peers = peerIDs
+ }
+ account.Groups[group.ID] = group
+ }
+
+ account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
+ for i := range account.RoutesG {
+ route := &account.RoutesG[i]
+ account.Routes[route.ID] = route
}
- account.RoutesG = nil
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
- for _, ns := range account.NameServerGroupsG {
- account.NameServerGroups[ns.ID] = ns.Copy()
+ for i := range account.NameServerGroupsG {
+ nsg := &account.NameServerGroupsG[i]
+ nsg.AccountID = ""
+ account.NameServerGroups[nsg.ID] = nsg
}
+
+ account.SetupKeysG = nil
+ account.PeersG = nil
+ account.UsersG = nil
+ account.GroupsG = nil
+ account.RoutesG = nil
account.NameServerGroupsG = nil
+ return account, nil
+}
+
+func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) {
+ var account types.Account
+ account.Network = &types.Network{}
+ const accountQuery = `
+ SELECT
+ id, created_by, created_at, domain, domain_category, is_domain_primary_account,
+ -- Embedded Network
+ network_identifier, network_net, network_dns, network_serial,
+ -- Embedded DNSSettings
+ dns_settings_disabled_management_groups,
+ -- Embedded Settings
+ settings_peer_login_expiration_enabled, settings_peer_login_expiration,
+ settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration,
+ settings_regular_users_view_blocked, settings_groups_propagation_enabled,
+ settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups,
+ settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range,
+ settings_lazy_connection_enabled,
+ -- Embedded ExtraSettings
+ settings_extra_peer_approval_enabled, settings_extra_user_approval_required,
+ settings_extra_integrated_validator, settings_extra_integrated_validator_groups
+ FROM accounts WHERE id = $1`
+
+ var (
+ sPeerLoginExpirationEnabled sql.NullBool
+ sPeerLoginExpiration sql.NullInt64
+ sPeerInactivityExpirationEnabled sql.NullBool
+ sPeerInactivityExpiration sql.NullInt64
+ sRegularUsersViewBlocked sql.NullBool
+ sGroupsPropagationEnabled sql.NullBool
+ sJWTGroupsEnabled sql.NullBool
+ sJWTGroupsClaimName sql.NullString
+ sJWTAllowGroups sql.NullString
+ sRoutingPeerDNSResolutionEnabled sql.NullBool
+ sDNSDomain sql.NullString
+ sNetworkRange sql.NullString
+ sLazyConnectionEnabled sql.NullBool
+ sExtraPeerApprovalEnabled sql.NullBool
+ sExtraUserApprovalRequired sql.NullBool
+ sExtraIntegratedValidator sql.NullString
+ sExtraIntegratedValidatorGroups sql.NullString
+ networkNet sql.NullString
+ dnsSettingsDisabledGroups sql.NullString
+ networkIdentifier sql.NullString
+ networkDns sql.NullString
+ networkSerial sql.NullInt64
+ createdAt sql.NullTime
+ )
+ err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan(
+ &account.Id, &account.CreatedBy, &createdAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount,
+ &networkIdentifier, &networkNet, &networkDns, &networkSerial,
+ &dnsSettingsDisabledGroups,
+ &sPeerLoginExpirationEnabled, &sPeerLoginExpiration,
+ &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration,
+ &sRegularUsersViewBlocked, &sGroupsPropagationEnabled,
+ &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups,
+ &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange,
+ &sLazyConnectionEnabled,
+ &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired,
+ &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups,
+ )
+ if err != nil {
+ if errors.Is(err, pgx.ErrNoRows) {
+ return nil, status.NewAccountNotFoundError(accountID)
+ }
+ return nil, status.NewGetAccountFromStoreError(err)
+ }
+
+ account.Settings = &types.Settings{Extra: &types.ExtraSettings{}}
+ if networkNet.Valid {
+ _ = json.Unmarshal([]byte(networkNet.String), &account.Network.Net)
+ }
+ if createdAt.Valid {
+ account.CreatedAt = createdAt.Time
+ }
+ if dnsSettingsDisabledGroups.Valid {
+ _ = json.Unmarshal([]byte(dnsSettingsDisabledGroups.String), &account.DNSSettings.DisabledManagementGroups)
+ }
+ if networkIdentifier.Valid {
+ account.Network.Identifier = networkIdentifier.String
+ }
+ if networkDns.Valid {
+ account.Network.Dns = networkDns.String
+ }
+ if networkSerial.Valid {
+ account.Network.Serial = uint64(networkSerial.Int64)
+ }
+ if sPeerLoginExpirationEnabled.Valid {
+ account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool
+ }
+ if sPeerLoginExpiration.Valid {
+ account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64)
+ }
+ if sPeerInactivityExpirationEnabled.Valid {
+ account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool
+ }
+ if sPeerInactivityExpiration.Valid {
+ account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64)
+ }
+ if sRegularUsersViewBlocked.Valid {
+ account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool
+ }
+ if sGroupsPropagationEnabled.Valid {
+ account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool
+ }
+ if sJWTGroupsEnabled.Valid {
+ account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool
+ }
+ if sJWTGroupsClaimName.Valid {
+ account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String
+ }
+ if sRoutingPeerDNSResolutionEnabled.Valid {
+ account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool
+ }
+ if sDNSDomain.Valid {
+ account.Settings.DNSDomain = sDNSDomain.String
+ }
+ if sLazyConnectionEnabled.Valid {
+ account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool
+ }
+ if sJWTAllowGroups.Valid {
+ _ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups)
+ }
+ if sNetworkRange.Valid {
+ _ = json.Unmarshal([]byte(sNetworkRange.String), &account.Settings.NetworkRange)
+ }
+
+ if sExtraPeerApprovalEnabled.Valid {
+ account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool
+ }
+ if sExtraUserApprovalRequired.Valid {
+ account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool
+ }
+ if sExtraIntegratedValidator.Valid {
+ account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String
+ }
+ if sExtraIntegratedValidatorGroups.Valid {
+ _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups)
+ }
+ account.InitOnce()
return &account, nil
}
+func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) {
+ const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at,
+ revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) {
+ var sk types.SetupKey
+ var autoGroups []byte
+ var skCreatedAt, expiresAt, updatedAt, lastUsed sql.NullTime
+ var revoked, ephemeral, allowExtraDNSLabels sql.NullBool
+ var usedTimes, usageLimit sql.NullInt64
+
+ err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &skCreatedAt,
+ &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels)
+
+ if err == nil {
+ if expiresAt.Valid {
+ sk.ExpiresAt = &expiresAt.Time
+ }
+ if skCreatedAt.Valid {
+ sk.CreatedAt = skCreatedAt.Time
+ }
+ if updatedAt.Valid {
+ sk.UpdatedAt = updatedAt.Time
+ if sk.UpdatedAt.IsZero() {
+ sk.UpdatedAt = sk.CreatedAt
+ }
+ }
+ if lastUsed.Valid {
+ sk.LastUsed = &lastUsed.Time
+ }
+ if revoked.Valid {
+ sk.Revoked = revoked.Bool
+ }
+ if usedTimes.Valid {
+ sk.UsedTimes = int(usedTimes.Int64)
+ }
+ if usageLimit.Valid {
+ sk.UsageLimit = int(usageLimit.Int64)
+ }
+ if ephemeral.Valid {
+ sk.Ephemeral = ephemeral.Bool
+ }
+ if allowExtraDNSLabels.Valid {
+ sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool
+ }
+ if autoGroups != nil {
+ _ = json.Unmarshal(autoGroups, &sk.AutoGroups)
+ } else {
+ sk.AutoGroups = []string{}
+ }
+ }
+ return sk, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return keys, nil
+}
+
+func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) {
+ const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled,
+ inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname,
+ meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version,
+ meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
+ meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
+ peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
+ location_geo_name_id FROM peers WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) {
+ var p nbpeer.Peer
+ p.Status = &nbpeer.PeerStatus{}
+ var (
+ lastLogin, createdAt sql.NullTime
+ sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
+ peerStatusLastSeen sql.NullTime
+ peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool
+ ip, extraDNS, netAddr, env, flags, files, connIP []byte
+ metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
+ metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString
+ metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString
+ locationCountryCode, locationCityName sql.NullString
+ locationGeoNameID sql.NullInt64
+ )
+
+ err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled,
+ &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &createdAt, &ephemeral, &extraDNS,
+ &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform,
+ &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
+ &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files,
+ &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
+ &locationCountryCode, &locationCityName, &locationGeoNameID)
+
+ if err == nil {
+ if lastLogin.Valid {
+ p.LastLogin = &lastLogin.Time
+ }
+ if createdAt.Valid {
+ p.CreatedAt = createdAt.Time
+ }
+ if sshEnabled.Valid {
+ p.SSHEnabled = sshEnabled.Bool
+ }
+ if loginExpirationEnabled.Valid {
+ p.LoginExpirationEnabled = loginExpirationEnabled.Bool
+ }
+ if inactivityExpirationEnabled.Valid {
+ p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool
+ }
+ if ephemeral.Valid {
+ p.Ephemeral = ephemeral.Bool
+ }
+ if allowExtraDNSLabels.Valid {
+ p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool
+ }
+ if peerStatusLastSeen.Valid {
+ p.Status.LastSeen = peerStatusLastSeen.Time
+ }
+ if peerStatusConnected.Valid {
+ p.Status.Connected = peerStatusConnected.Bool
+ }
+ if peerStatusLoginExpired.Valid {
+ p.Status.LoginExpired = peerStatusLoginExpired.Bool
+ }
+ if peerStatusRequiresApproval.Valid {
+ p.Status.RequiresApproval = peerStatusRequiresApproval.Bool
+ }
+ if metaHostname.Valid {
+ p.Meta.Hostname = metaHostname.String
+ }
+ if metaGoOS.Valid {
+ p.Meta.GoOS = metaGoOS.String
+ }
+ if metaKernel.Valid {
+ p.Meta.Kernel = metaKernel.String
+ }
+ if metaCore.Valid {
+ p.Meta.Core = metaCore.String
+ }
+ if metaPlatform.Valid {
+ p.Meta.Platform = metaPlatform.String
+ }
+ if metaOS.Valid {
+ p.Meta.OS = metaOS.String
+ }
+ if metaOSVersion.Valid {
+ p.Meta.OSVersion = metaOSVersion.String
+ }
+ if metaWtVersion.Valid {
+ p.Meta.WtVersion = metaWtVersion.String
+ }
+ if metaUIVersion.Valid {
+ p.Meta.UIVersion = metaUIVersion.String
+ }
+ if metaKernelVersion.Valid {
+ p.Meta.KernelVersion = metaKernelVersion.String
+ }
+ if metaSystemSerialNumber.Valid {
+ p.Meta.SystemSerialNumber = metaSystemSerialNumber.String
+ }
+ if metaSystemProductName.Valid {
+ p.Meta.SystemProductName = metaSystemProductName.String
+ }
+ if metaSystemManufacturer.Valid {
+ p.Meta.SystemManufacturer = metaSystemManufacturer.String
+ }
+ if locationCountryCode.Valid {
+ p.Location.CountryCode = locationCountryCode.String
+ }
+ if locationCityName.Valid {
+ p.Location.CityName = locationCityName.String
+ }
+ if locationGeoNameID.Valid {
+ p.Location.GeoNameID = uint(locationGeoNameID.Int64)
+ }
+ if ip != nil {
+ _ = json.Unmarshal(ip, &p.IP)
+ }
+ if extraDNS != nil {
+ _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels)
+ }
+ if netAddr != nil {
+ _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses)
+ }
+ if env != nil {
+ _ = json.Unmarshal(env, &p.Meta.Environment)
+ }
+ if flags != nil {
+ _ = json.Unmarshal(flags, &p.Meta.Flags)
+ }
+ if files != nil {
+ _ = json.Unmarshal(files, &p.Meta.Files)
+ }
+ if connIP != nil {
+ _ = json.Unmarshal(connIP, &p.Location.ConnectionIP)
+ }
+ }
+ return p, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return peers, nil
+}
+
+func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
+ const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+ users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
+ var u types.User
+ var autoGroups []byte
+ var lastLogin, createdAt sql.NullTime
+ var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
+ err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType)
+ if err == nil {
+ if lastLogin.Valid {
+ u.LastLogin = &lastLogin.Time
+ }
+ if createdAt.Valid {
+ u.CreatedAt = createdAt.Time
+ }
+ if isServiceUser.Valid {
+ u.IsServiceUser = isServiceUser.Bool
+ }
+ if nonDeletable.Valid {
+ u.NonDeletable = nonDeletable.Bool
+ }
+ if blocked.Valid {
+ u.Blocked = blocked.Bool
+ }
+ if pendingApproval.Valid {
+ u.PendingApproval = pendingApproval.Bool
+ }
+ if autoGroups != nil {
+ _ = json.Unmarshal(autoGroups, &u.AutoGroups)
+ } else {
+ u.AutoGroups = []string{}
+ }
+ }
+ return u, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return users, nil
+}
+
+func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) {
+ const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+ groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) {
+ var g types.Group
+ var resources []byte
+ var refID sql.NullInt64
+ var refType sql.NullString
+ err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType)
+ if err == nil {
+ if refID.Valid {
+ g.IntegrationReference.ID = int(refID.Int64)
+ }
+ if refType.Valid {
+ g.IntegrationReference.IntegrationType = refType.String
+ }
+ if resources != nil {
+ _ = json.Unmarshal(resources, &g.Resources)
+ } else {
+ g.Resources = []types.Resource{}
+ }
+ g.GroupPeers = []types.GroupPeer{}
+ g.Peers = []string{}
+ }
+ return &g, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return groups, nil
+}
+
+func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) {
+ const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+ policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) {
+ var p types.Policy
+ var checks []byte
+ var enabled sql.NullBool
+ err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks)
+ if err == nil {
+ if enabled.Valid {
+ p.Enabled = enabled.Bool
+ }
+ if checks != nil {
+ _ = json.Unmarshal(checks, &p.SourcePostureChecks)
+ }
+ }
+ return &p, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return policies, nil
+}
+
+func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) {
+ const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+ routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) {
+ var r route.Route
+ var network, domains, peerGroups, groups, accessGroups []byte
+ var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool
+ var metric sql.NullInt64
+ err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
+ if err == nil {
+ if keepRoute.Valid {
+ r.KeepRoute = keepRoute.Bool
+ }
+ if masquerade.Valid {
+ r.Masquerade = masquerade.Bool
+ }
+ if enabled.Valid {
+ r.Enabled = enabled.Bool
+ }
+ if skipAutoApply.Valid {
+ r.SkipAutoApply = skipAutoApply.Bool
+ }
+ if metric.Valid {
+ r.Metric = int(metric.Int64)
+ }
+ if network != nil {
+ _ = json.Unmarshal(network, &r.Network)
+ }
+ if domains != nil {
+ _ = json.Unmarshal(domains, &r.Domains)
+ }
+ if peerGroups != nil {
+ _ = json.Unmarshal(peerGroups, &r.PeerGroups)
+ }
+ if groups != nil {
+ _ = json.Unmarshal(groups, &r.Groups)
+ }
+ if accessGroups != nil {
+ _ = json.Unmarshal(accessGroups, &r.AccessControlGroups)
+ }
+ }
+ return r, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return routes, nil
+}
+
+func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) {
+ const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+ nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) {
+ var n nbdns.NameServerGroup
+ var ns, groups, domains []byte
+ var primary, enabled, searchDomainsEnabled sql.NullBool
+ err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
+ if err == nil {
+ if primary.Valid {
+ n.Primary = primary.Bool
+ }
+ if enabled.Valid {
+ n.Enabled = enabled.Bool
+ }
+ if searchDomainsEnabled.Valid {
+ n.SearchDomainsEnabled = searchDomainsEnabled.Bool
+ }
+ if ns != nil {
+ _ = json.Unmarshal(ns, &n.NameServers)
+ } else {
+ n.NameServers = []nbdns.NameServer{}
+ }
+ if groups != nil {
+ _ = json.Unmarshal(groups, &n.Groups)
+ } else {
+ n.Groups = []string{}
+ }
+ if domains != nil {
+ _ = json.Unmarshal(domains, &n.Domains)
+ } else {
+ n.Domains = []string{}
+ }
+ }
+ return n, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return nsgs, nil
+}
+
+func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
+ const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+ checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
+ var c posture.Checks
+ var checksDef []byte
+ err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef)
+ if err == nil && checksDef != nil {
+ _ = json.Unmarshal(checksDef, &c.Checks)
+ }
+ return &c, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return checks, nil
+}
+
+func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
+ const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+ networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network])
+ if err != nil {
+ return nil, err
+ }
+ result := make([]*networkTypes.Network, len(networks))
+ for i := range networks {
+ result[i] = &networks[i]
+ }
+ return result, nil
+}
+
+func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) {
+ const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+ routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) {
+ var r routerTypes.NetworkRouter
+ var peerGroups []byte
+ var masquerade, enabled sql.NullBool
+ var metric sql.NullInt64
+ err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
+ if err == nil {
+ if masquerade.Valid {
+ r.Masquerade = masquerade.Bool
+ }
+ if enabled.Valid {
+ r.Enabled = enabled.Bool
+ }
+ if metric.Valid {
+ r.Metric = int(metric.Int64)
+ }
+ if peerGroups != nil {
+ _ = json.Unmarshal(peerGroups, &r.PeerGroups)
+ }
+ }
+ return r, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ result := make([]*routerTypes.NetworkRouter, len(routers))
+ for i := range routers {
+ result[i] = &routers[i]
+ }
+ return result, nil
+}
+
+func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) {
+ const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
+ rows, err := s.pool.Query(ctx, query, accountID)
+ if err != nil {
+ return nil, err
+ }
+ resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) {
+ var r resourceTypes.NetworkResource
+ var prefix []byte
+ var enabled sql.NullBool
+ err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
+ if err == nil {
+ if enabled.Valid {
+ r.Enabled = enabled.Bool
+ }
+ if prefix != nil {
+ _ = json.Unmarshal(prefix, &r.Prefix)
+ }
+ }
+ return r, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ result := make([]*resourceTypes.NetworkResource, len(resources))
+ for i := range resources {
+ result[i] = &resources[i]
+ }
+ return result, nil
+}
+
+func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error {
+ const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1`
+ var onboardingFlowPending, signupFormPending sql.NullBool
+ var createdAt, updatedAt sql.NullTime
+ err := s.pool.QueryRow(ctx, query, accountID).Scan(
+ &account.Onboarding.AccountID,
+ &onboardingFlowPending,
+ &signupFormPending,
+ &createdAt,
+ &updatedAt,
+ )
+ if err != nil && !errors.Is(err, pgx.ErrNoRows) {
+ return err
+ }
+ if createdAt.Valid {
+ account.Onboarding.CreatedAt = createdAt.Time
+ }
+ if updatedAt.Valid {
+ account.Onboarding.UpdatedAt = updatedAt.Time
+ }
+ if onboardingFlowPending.Valid {
+ account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool
+ }
+ if signupFormPending.Valid {
+ account.Onboarding.SignupFormPending = signupFormPending.Bool
+ }
+ return nil
+}
+
+func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) {
+ if len(userIDs) == 0 {
+ return nil, nil
+ }
+ const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)`
+ rows, err := s.pool.Query(ctx, query, userIDs)
+ if err != nil {
+ return nil, err
+ }
+ pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) {
+ var pat types.PersonalAccessToken
+ var expirationDate, lastUsed, createdAt sql.NullTime
+ err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &createdAt, &lastUsed)
+ if err == nil {
+ if expirationDate.Valid {
+ pat.ExpirationDate = &expirationDate.Time
+ }
+ if createdAt.Valid {
+ pat.CreatedAt = createdAt.Time
+ }
+ if lastUsed.Valid {
+ pat.LastUsed = &lastUsed.Time
+ }
+ }
+ return pat, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return pats, nil
+}
+
+func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) {
+ if len(policyIDs) == 0 {
+ return nil, nil
+ }
+ const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)`
+ rows, err := s.pool.Query(ctx, query, policyIDs)
+ if err != nil {
+ return nil, err
+ }
+ rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
+ var r types.PolicyRule
+ var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte
+ var enabled, bidirectional sql.NullBool
+ var authorizedUser sql.NullString
+ err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &authorizedUser)
+ if err == nil {
+ if enabled.Valid {
+ r.Enabled = enabled.Bool
+ }
+ if bidirectional.Valid {
+ r.Bidirectional = bidirectional.Bool
+ }
+ if dest != nil {
+ _ = json.Unmarshal(dest, &r.Destinations)
+ }
+ if destRes != nil {
+ _ = json.Unmarshal(destRes, &r.DestinationResource)
+ }
+ if sources != nil {
+ _ = json.Unmarshal(sources, &r.Sources)
+ }
+ if sourceRes != nil {
+ _ = json.Unmarshal(sourceRes, &r.SourceResource)
+ }
+ if ports != nil {
+ _ = json.Unmarshal(ports, &r.Ports)
+ }
+ if portRanges != nil {
+ _ = json.Unmarshal(portRanges, &r.PortRanges)
+ }
+ if authorizedGroups != nil {
+ _ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups)
+ }
+ if authorizedUser.Valid {
+ r.AuthorizedUser = authorizedUser.String
+ }
+ }
+ return &r, err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return rules, nil
+}
+
+func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) {
+ if len(groupIDs) == 0 {
+ return nil, nil
+ }
+ const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)`
+ rows, err := s.pool.Query(ctx, query, groupIDs)
+ if err != nil {
+ return nil, err
+ }
+ groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer])
+ if err != nil {
+ return nil, err
+ }
+ return groupPeers, nil
+}
+
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
var user types.User
result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
@@ -1050,16 +2167,13 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
- ctx, cancel := getDebuggingCtx(ctx)
- defer cancel()
-
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountNetwork types.AccountNetwork
- if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
+ if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
@@ -1069,16 +2183,13 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
- ctx, cancel := getDebuggingCtx(ctx)
- defer cancel()
-
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peer nbpeer.Peer
- result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
+ result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -1127,11 +2238,8 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
- ctx, cancel := getDebuggingCtx(ctx)
- defer cancel()
-
var user types.User
- result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
+ result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
@@ -1199,8 +2307,41 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
if err != nil {
return nil, err
}
+ pool, err := connectToPgDb(context.Background(), dsn)
+ if err != nil {
+ return nil, err
+ }
+ store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration)
+ if err != nil {
+ pool.Close()
+ return nil, err
+ }
+ store.pool = pool
+ return store, nil
+}
- return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration)
+func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
+ config, err := pgxpool.ParseConfig(dsn)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse database config: %w", err)
+ }
+
+ config.MaxConns = pgMaxConnections
+ config.MinConns = pgMinConnections
+ config.MaxConnLifetime = pgMaxConnLifetime
+ config.HealthCheckPeriod = pgHealthCheckPeriod
+
+ pool, err := pgxpool.NewWithConfig(ctx, config)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create connection pool: %w", err)
+ }
+
+ if err := pool.Ping(ctx); err != nil {
+ pool.Close()
+ return nil, fmt.Errorf("unable to ping database: %w", err)
+ }
+
+ return pool, nil
}
// NewMysqlStore creates a new MySQL store.
@@ -1269,7 +2410,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB.
func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
- store, err := NewPostgresqlStore(ctx, dsn, metrics, false)
+ store, err := NewPostgresqlStoreForTests(ctx, dsn, metrics, false)
if err != nil {
return nil, err
}
@@ -1289,6 +2430,50 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore,
return store, nil
}
+// used for tests only
+func NewPostgresqlStoreForTests(ctx context.Context, dsn string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
+ db, err := gorm.Open(postgres.Open(dsn), getGormConfig())
+ if err != nil {
+ return nil, err
+ }
+ pool, err := connectToPgDbForTests(context.Background(), dsn)
+ if err != nil {
+ return nil, err
+ }
+ store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration)
+ if err != nil {
+ pool.Close()
+ return nil, err
+ }
+ store.pool = pool
+ return store, nil
+}
+
+// used for tests only
+func connectToPgDbForTests(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
+ config, err := pgxpool.ParseConfig(dsn)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse database config: %w", err)
+ }
+
+ config.MaxConns = 5
+ config.MinConns = 1
+ config.MaxConnLifetime = 30 * time.Second
+ config.HealthCheckPeriod = 10 * time.Second
+
+ pool, err := pgxpool.NewWithConfig(ctx, config)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create connection pool: %w", err)
+ }
+
+ if err := pool.Ping(ctx); err != nil {
+ pool.Close()
+ return nil, fmt.Errorf("unable to ping database: %w", err)
+ }
+
+ return pool, nil
+}
+
// NewMysqlStoreFromSqlStore restores a store from SqlStore and stores MySQL DB.
func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewMysqlStore(ctx, dsn, metrics, false)
@@ -1312,16 +2497,13 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
- ctx, cancel := getDebuggingCtx(ctx)
- defer cancel()
-
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKey types.SetupKey
- result := tx.WithContext(ctx).
+ result := tx.
Take(&setupKey, GetKeyQueryCondition(s), key)
if result.Error != nil {
@@ -1335,10 +2517,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
- ctx, cancel := getDebuggingCtx(ctx)
- defer cancel()
-
- result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
+ result := s.db.Model(&types.SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
@@ -1358,11 +2537,8 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
- ctx, cancel := getDebuggingCtx(ctx)
- defer cancel()
-
var groupID string
- _ = s.db.WithContext(ctx).Model(types.Group{}).
+ _ = s.db.Model(types.Group{}).
Select("id").
Where("account_id = ? AND name = ?", accountID, "All").
Limit(1).
@@ -1390,9 +2566,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
// AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
- ctx, cancel := getDebuggingCtx(ctx)
- defer cancel()
-
peer := &types.GroupPeer{
AccountID: accountID,
GroupID: groupID,
@@ -1589,10 +2762,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
- ctx, cancel := getDebuggingCtx(ctx)
- defer cancel()
-
- if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
+ if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -1718,10 +2888,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
- ctx, cancel := getDebuggingCtx(ctx)
- defer cancel()
-
- result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
+ result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -1735,6 +2902,33 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
if tx.Error != nil {
return tx.Error
}
+ defer func() {
+ if r := recover(); r != nil {
+ tx.Rollback()
+ panic(r)
+ }
+ }()
+
+ if s.storeEngine == types.PostgresStoreEngine {
+ if err := tx.Exec("SET LOCAL statement_timeout = '1min'").Error; err != nil {
+ tx.Rollback()
+ return fmt.Errorf("failed to set statement timeout: %w", err)
+ }
+ if err := tx.Exec("SET LOCAL lock_timeout = '1min'").Error; err != nil {
+ tx.Rollback()
+ return fmt.Errorf("failed to set lock timeout: %w", err)
+ }
+ }
+
+ // For MySQL, disable FK checks within this transaction to avoid deadlocks
+ // This is session-scoped and doesn't require SUPER privileges
+ if s.storeEngine == types.MysqlStoreEngine {
+ if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil {
+ tx.Rollback()
+ return fmt.Errorf("failed to disable FK checks: %w", err)
+ }
+ }
+
repo := s.withTx(tx)
err := operation(repo)
if err != nil {
@@ -1742,6 +2936,14 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
return err
}
+ // Re-enable FK checks before commit (optional, as transaction end resets it)
+ if s.storeEngine == types.MysqlStoreEngine {
+ if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; err != nil {
+ tx.Rollback()
+ return fmt.Errorf("failed to re-enable FK checks: %w", err)
+ }
+ }
+
err = tx.Commit().Error
log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
@@ -1759,6 +2961,31 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
}
}
+// transaction wraps a GORM transaction with MySQL-specific FK checks handling
+// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora
+func (s *SqlStore) transaction(fn func(*gorm.DB) error) error {
+ return s.db.Transaction(func(tx *gorm.DB) error {
+ // For MySQL, disable FK checks within this transaction to avoid deadlocks
+ // This is session-scoped and doesn't require SUPER privileges
+ if s.storeEngine == types.MysqlStoreEngine {
+ if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil {
+ return fmt.Errorf("failed to disable FK checks: %w", err)
+ }
+ }
+
+ err := fn(tx)
+
+ // Re-enable FK checks before commit (optional, as transaction end resets it)
+ if s.storeEngine == types.MysqlStoreEngine && err == nil {
+ if fkErr := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; fkErr != nil {
+ return fmt.Errorf("failed to re-enable FK checks: %w", fkErr)
+ }
+ }
+
+ return err
+ })
+}
+
func (s *SqlStore) GetDB() *gorm.DB {
return s.db
}
@@ -2015,7 +3242,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
}
func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error {
- return s.db.Transaction(func(tx *gorm.DB) error {
+ return s.transaction(func(tx *gorm.DB) error {
if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil {
return fmt.Errorf("delete policy rules: %w", err)
}
@@ -2783,36 +4010,6 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
return groupPeers, nil
}
-func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
- ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
- userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
- if ok {
- //nolint
- ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
- }
-
- requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
- if ok {
- //nolint
- ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
- }
-
- accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
- if ok {
- //nolint
- ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
- }
-
- go func() {
- select {
- case <-ctx.Done():
- case <-grpcCtx.Done():
- log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
- }
- }()
- return ctx, cancel
-}
-
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
var info types.PrimaryAccountInfo
result := s.db.Model(&types.Account{}).
@@ -2852,7 +4049,7 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
Network: &types.Network{Net: ipNet},
}
- result := s.db.WithContext(ctx).
+ result := s.db.
Model(&types.Account{}).
Where(idQueryCondition, accountID).
Updates(&patch)
diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go
new file mode 100644
index 000000000..8ff04d68a
--- /dev/null
+++ b/management/server/store/sql_store_get_account_test.go
@@ -0,0 +1,1089 @@
+package store
+
+import (
+ "context"
+ "net"
+ "net/netip"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/management/server/integration_reference"
+ resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
+ routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
+ networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/posture"
+ "github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/route"
+)
+
+// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads
+// all fields and nested objects from the database, including deeply nested structures.
+func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping comprehensive test in short mode")
+ }
+
+ ctx := context.Background()
+ store, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir())
+ require.NoError(t, err)
+ defer cleanup()
+
+ // Create comprehensive test data
+ accountID := "test-account-comprehensive"
+ userID1 := "user-1"
+ userID2 := "user-2"
+ peerID1 := "peer-1"
+ peerID2 := "peer-2"
+ peerID3 := "peer-3"
+ groupID1 := "group-1"
+ groupID2 := "group-2"
+ setupKeyID1 := "setup-key-1"
+ setupKeyID2 := "setup-key-2"
+ routeID1 := route.ID("route-1")
+ routeID2 := route.ID("route-2")
+ nsGroupID1 := "ns-group-1"
+ nsGroupID2 := "ns-group-2"
+ policyID1 := "policy-1"
+ policyID2 := "policy-2"
+ postureCheckID1 := "posture-check-1"
+ postureCheckID2 := "posture-check-2"
+ networkID1 := "network-1"
+ routerID1 := "router-1"
+ resourceID1 := "resource-1"
+ patID1 := "pat-1"
+ patID2 := "pat-2"
+ patID3 := "pat-3"
+
+ now := time.Now().UTC().Truncate(time.Second)
+ lastLogin := now.Add(-24 * time.Hour)
+ patLastUsed := now.Add(-1 * time.Hour)
+
+ // Build comprehensive account with all fields populated
+ account := &types.Account{
+ Id: accountID,
+ CreatedBy: userID1,
+ CreatedAt: now,
+ Domain: "example.com",
+ DomainCategory: "business",
+ IsDomainPrimaryAccount: true,
+ Network: &types.Network{
+ Identifier: "test-network",
+ Net: net.IPNet{
+ IP: net.ParseIP("100.64.0.0"),
+ Mask: net.CIDRMask(10, 32),
+ },
+ Dns: "test-dns",
+ Serial: 42,
+ },
+ DNSSettings: types.DNSSettings{
+ DisabledManagementGroups: []string{"dns-group-1", "dns-group-2"},
+ },
+ Settings: &types.Settings{
+ PeerLoginExpirationEnabled: true,
+ PeerLoginExpiration: time.Hour * 24 * 30,
+ GroupsPropagationEnabled: true,
+ JWTGroupsEnabled: true,
+ JWTGroupsClaimName: "groups",
+ JWTAllowGroups: []string{"allowed-group-1", "allowed-group-2"},
+ RegularUsersViewBlocked: false,
+ Extra: &types.ExtraSettings{
+ PeerApprovalEnabled: true,
+ IntegratedValidatorGroups: []string{"validator-1"},
+ },
+ },
+ }
+
+ // Create Setup Keys with all fields
+ setupKey1ExpiresAt := now.Add(30 * 24 * time.Hour)
+ setupKey1LastUsed := now.Add(-2 * time.Hour)
+ setupKey1 := &types.SetupKey{
+ Id: setupKeyID1,
+ AccountID: accountID,
+ Key: "setup-key-secret-1",
+ Name: "Setup Key 1",
+ Type: types.SetupKeyReusable,
+ CreatedAt: now,
+ UpdatedAt: now,
+ ExpiresAt: &setupKey1ExpiresAt,
+ Revoked: false,
+ UsedTimes: 5,
+ LastUsed: &setupKey1LastUsed,
+ AutoGroups: []string{groupID1, groupID2},
+ UsageLimit: 100,
+ Ephemeral: false,
+ }
+
+ setupKey2ExpiresAt := now.Add(7 * 24 * time.Hour)
+ setupKey2LastUsed := now.Add(-1 * time.Hour)
+ setupKey2 := &types.SetupKey{
+ Id: setupKeyID2,
+ AccountID: accountID,
+ Key: "setup-key-secret-2",
+ Name: "Setup Key 2 (One-off)",
+ Type: types.SetupKeyOneOff,
+ CreatedAt: now,
+ UpdatedAt: now,
+ ExpiresAt: &setupKey2ExpiresAt,
+ Revoked: true,
+ UsedTimes: 1,
+ LastUsed: &setupKey2LastUsed,
+ AutoGroups: []string{},
+ UsageLimit: 1,
+ Ephemeral: true,
+ }
+
+ account.SetupKeys = map[string]*types.SetupKey{
+ setupKey1.Key: setupKey1,
+ setupKey2.Key: setupKey2,
+ }
+
+ // Create Peers with comprehensive fields
+ peer1 := &nbpeer.Peer{
+ ID: peerID1,
+ AccountID: accountID,
+ Key: "peer-key-1-AAAA",
+ Name: "Peer 1",
+ IP: net.ParseIP("100.64.0.1"),
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "peer1.example.com",
+ GoOS: "linux",
+ Kernel: "5.15.0",
+ Core: "x86_64",
+ Platform: "ubuntu",
+ OS: "Ubuntu 22.04",
+ WtVersion: "0.24.0",
+ UIVersion: "0.24.0",
+ KernelVersion: "5.15.0-78-generic",
+ OSVersion: "22.04",
+ NetworkAddresses: []nbpeer.NetworkAddress{
+ {NetIP: netip.MustParsePrefix("192.168.1.10/32"), Mac: "00:11:22:33:44:55"},
+ {NetIP: netip.MustParsePrefix("10.0.0.5/32"), Mac: "00:11:22:33:44:66"},
+ },
+ SystemSerialNumber: "ABC123",
+ SystemProductName: "Server Model X",
+ SystemManufacturer: "Dell Inc.",
+ },
+ Status: &nbpeer.PeerStatus{
+ LastSeen: now.Add(-5 * time.Minute),
+ Connected: true,
+ LoginExpired: false,
+ RequiresApproval: false,
+ },
+ Location: nbpeer.Location{
+ ConnectionIP: net.ParseIP("203.0.113.10"),
+ CountryCode: "US",
+ CityName: "San Francisco",
+ GeoNameID: 5391959,
+ },
+ SSHEnabled: true,
+ SSHKey: "ssh-rsa AAAAB3NzaC1...",
+ UserID: userID1,
+ LoginExpirationEnabled: true,
+ InactivityExpirationEnabled: false,
+ DNSLabel: "peer1",
+ CreatedAt: now.Add(-30 * 24 * time.Hour),
+ Ephemeral: false,
+ }
+
+ peer2 := &nbpeer.Peer{
+ ID: peerID2,
+ AccountID: accountID,
+ Key: "peer-key-2-BBBB",
+ Name: "Peer 2",
+ IP: net.ParseIP("100.64.0.2"),
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "peer2.example.com",
+ GoOS: "darwin",
+ Kernel: "22.0.0",
+ Core: "arm64",
+ Platform: "darwin",
+ OS: "macOS Ventura",
+ WtVersion: "0.24.0",
+ UIVersion: "0.24.0",
+ },
+ Status: &nbpeer.PeerStatus{
+ LastSeen: now.Add(-1 * time.Hour),
+ Connected: false,
+ LoginExpired: true,
+ RequiresApproval: true,
+ },
+ Location: nbpeer.Location{
+ ConnectionIP: net.ParseIP("198.51.100.20"),
+ CountryCode: "GB",
+ CityName: "London",
+ GeoNameID: 2643743,
+ },
+ SSHEnabled: false,
+ UserID: userID2,
+ LoginExpirationEnabled: false,
+ InactivityExpirationEnabled: true,
+ DNSLabel: "peer2",
+ CreatedAt: now.Add(-15 * 24 * time.Hour),
+ Ephemeral: false,
+ }
+
+ peer3 := &nbpeer.Peer{
+ ID: peerID3,
+ AccountID: accountID,
+ Key: "peer-key-3-CCCC",
+ Name: "Peer 3 (Ephemeral)",
+ IP: net.ParseIP("100.64.0.3"),
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "peer3.example.com",
+ GoOS: "windows",
+ Platform: "windows",
+ },
+ Status: &nbpeer.PeerStatus{
+ LastSeen: now.Add(-10 * time.Minute),
+ Connected: true,
+ },
+ DNSLabel: "peer3",
+ CreatedAt: now.Add(-1 * time.Hour),
+ Ephemeral: true,
+ }
+
+ account.Peers = map[string]*nbpeer.Peer{
+ peerID1: peer1,
+ peerID2: peer2,
+ peerID3: peer3,
+ }
+
+ // Create Users with PATs
+ pat1ExpirationDate := now.Add(90 * 24 * time.Hour)
+ pat1 := &types.PersonalAccessToken{
+ ID: patID1,
+ Name: "PAT 1",
+ HashedToken: "hashed-token-1",
+ ExpirationDate: &pat1ExpirationDate,
+ CreatedAt: now.Add(-10 * 24 * time.Hour),
+ CreatedBy: userID1,
+ LastUsed: &patLastUsed,
+ }
+
+ pat2ExpirationDate := now.Add(30 * 24 * time.Hour)
+ pat2 := &types.PersonalAccessToken{
+ ID: patID2,
+ Name: "PAT 2",
+ HashedToken: "hashed-token-2",
+ ExpirationDate: &pat2ExpirationDate,
+ CreatedAt: now.Add(-5 * 24 * time.Hour),
+ CreatedBy: userID1,
+ }
+
+ pat3ExpirationDate := now.Add(60 * 24 * time.Hour)
+ pat3 := &types.PersonalAccessToken{
+ ID: patID3,
+ Name: "PAT 3",
+ HashedToken: "hashed-token-3",
+ ExpirationDate: &pat3ExpirationDate,
+ CreatedAt: now.Add(-2 * 24 * time.Hour),
+ CreatedBy: userID2,
+ }
+
+ user1 := &types.User{
+ Id: userID1,
+ AccountID: accountID,
+ Role: types.UserRoleOwner,
+ IsServiceUser: false,
+ NonDeletable: true,
+ AutoGroups: []string{groupID1},
+ Issued: types.UserIssuedAPI,
+ IntegrationReference: integration_reference.IntegrationReference{
+ ID: 123,
+ IntegrationType: "azure_ad",
+ },
+ CreatedAt: now.Add(-60 * 24 * time.Hour),
+ LastLogin: &lastLogin,
+ Blocked: false,
+ PATs: map[string]*types.PersonalAccessToken{
+ patID1: pat1,
+ patID2: pat2,
+ },
+ }
+
+ user2 := &types.User{
+ Id: userID2,
+ AccountID: accountID,
+ Role: types.UserRoleAdmin,
+ IsServiceUser: true,
+ NonDeletable: false,
+ AutoGroups: []string{groupID2},
+ Issued: types.UserIssuedIntegration,
+ IntegrationReference: integration_reference.IntegrationReference{
+ ID: 456,
+ IntegrationType: "google_workspace",
+ },
+ CreatedAt: now.Add(-30 * 24 * time.Hour),
+ Blocked: false,
+ PATs: map[string]*types.PersonalAccessToken{
+ patID3: pat3,
+ },
+ }
+
+ account.Users = map[string]*types.User{
+ userID1: user1,
+ userID2: user2,
+ }
+
+ // Create Groups with peers and resources
+ group1 := &types.Group{
+ ID: groupID1,
+ AccountID: accountID,
+ Name: "Group 1",
+ Issued: types.GroupIssuedAPI,
+ Peers: []string{peerID1, peerID2},
+ Resources: []types.Resource{
+ {
+ ID: "resource-1",
+ Type: types.ResourceTypeHost,
+ },
+ },
+ }
+
+ group2 := &types.Group{
+ ID: groupID2,
+ AccountID: accountID,
+ Name: "Group 2",
+ Issued: types.GroupIssuedIntegration,
+ IntegrationReference: integration_reference.IntegrationReference{
+ ID: 789,
+ IntegrationType: "okta",
+ },
+ Peers: []string{peerID3},
+ Resources: []types.Resource{},
+ }
+
+ account.Groups = map[string]*types.Group{
+ groupID1: group1,
+ groupID2: group2,
+ }
+
+ // Create Policies with Rules
+ policy1 := &types.Policy{
+ ID: policyID1,
+ AccountID: accountID,
+ Name: "Policy 1",
+ Description: "Main access policy",
+ Enabled: true,
+ Rules: []*types.PolicyRule{
+ {
+ ID: "rule-1",
+ PolicyID: policyID1,
+ Name: "Rule 1",
+ Description: "Allow access",
+ Enabled: true,
+ Action: types.PolicyTrafficActionAccept,
+ Bidirectional: true,
+ Protocol: types.PolicyRuleProtocolALL,
+ Ports: []string{},
+ PortRanges: []types.RulePortRange{},
+ Sources: []string{groupID1},
+ Destinations: []string{groupID2},
+ },
+ {
+ ID: "rule-2",
+ PolicyID: policyID1,
+ Name: "Rule 2",
+ Description: "Block traffic on specific ports",
+ Enabled: true,
+ Action: types.PolicyTrafficActionDrop,
+ Bidirectional: false,
+ Protocol: types.PolicyRuleProtocolTCP,
+ Ports: []string{"22", "3389"},
+ PortRanges: []types.RulePortRange{
+ {Start: 8000, End: 8999},
+ },
+ Sources: []string{groupID2},
+ Destinations: []string{groupID1},
+ },
+ },
+ }
+
+ policy2 := &types.Policy{
+ ID: policyID2,
+ AccountID: accountID,
+ Name: "Policy 2",
+ Description: "Secondary policy",
+ Enabled: false,
+ Rules: []*types.PolicyRule{
+ {
+ ID: "rule-3",
+ PolicyID: policyID2,
+ Name: "Rule 3",
+ Description: "UDP access",
+ Enabled: false,
+ Action: types.PolicyTrafficActionAccept,
+ Bidirectional: true,
+ Protocol: types.PolicyRuleProtocolUDP,
+ Ports: []string{"53"},
+ Sources: []string{groupID1},
+ Destinations: []string{groupID1},
+ },
+ },
+ }
+
+ account.Policies = []*types.Policy{policy1, policy2}
+
+ // Create Routes
+ route1 := &route.Route{
+ ID: routeID1,
+ AccountID: accountID,
+ Network: netip.MustParsePrefix("10.0.0.0/24"),
+ NetworkType: route.IPv4Network,
+ Peer: peerID1,
+ PeerGroups: []string{},
+ Description: "Route 1",
+ NetID: "net-id-1",
+ Masquerade: true,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{groupID1},
+ AccessControlGroups: []string{groupID2},
+ }
+
+ route2 := &route.Route{
+ ID: routeID2,
+ AccountID: accountID,
+ Network: netip.MustParsePrefix("192.168.1.0/24"),
+ NetworkType: route.IPv4Network,
+ Peer: "",
+ PeerGroups: []string{groupID2},
+ Description: "Route 2 (High Availability)",
+ NetID: "net-id-2",
+ Masquerade: false,
+ Metric: 100,
+ Enabled: true,
+ Groups: []string{groupID1, groupID2},
+ AccessControlGroups: []string{groupID1},
+ }
+
+ account.Routes = map[route.ID]*route.Route{
+ routeID1: route1,
+ routeID2: route2,
+ }
+
+ // Create NameServer Groups
+ nsGroup1 := &nbdns.NameServerGroup{
+ ID: nsGroupID1,
+ AccountID: accountID,
+ Name: "NS Group 1",
+ Description: "Primary nameservers",
+ NameServers: []nbdns.NameServer{
+ {
+ IP: netip.MustParseAddr("8.8.8.8"),
+ NSType: nbdns.UDPNameServerType,
+ Port: 53,
+ },
+ {
+ IP: netip.MustParseAddr("8.8.4.4"),
+ NSType: nbdns.UDPNameServerType,
+ Port: 53,
+ },
+ },
+ Groups: []string{groupID1, groupID2},
+ Domains: []string{"example.com", "test.com"},
+ Enabled: true,
+ Primary: true,
+ SearchDomainsEnabled: true,
+ }
+
+ nsGroup2 := &nbdns.NameServerGroup{
+ ID: nsGroupID2,
+ AccountID: accountID,
+ Name: "NS Group 2",
+ Description: "Secondary nameservers",
+ NameServers: []nbdns.NameServer{
+ {
+ IP: netip.MustParseAddr("1.1.1.1"),
+ NSType: nbdns.UDPNameServerType,
+ Port: 53,
+ },
+ },
+ Groups: []string{},
+ Domains: []string{},
+ Enabled: false,
+ Primary: false,
+ SearchDomainsEnabled: false,
+ }
+
+ account.NameServerGroups = map[string]*nbdns.NameServerGroup{
+ nsGroupID1: nsGroup1,
+ nsGroupID2: nsGroup2,
+ }
+
+ // Create Posture Checks
+ postureCheck1 := &posture.Checks{
+ ID: postureCheckID1,
+ AccountID: accountID,
+ Name: "Posture Check 1",
+ Description: "OS version check",
+ Checks: posture.ChecksDefinition{
+ NBVersionCheck: &posture.NBVersionCheck{
+ MinVersion: "0.24.0",
+ },
+ OSVersionCheck: &posture.OSVersionCheck{
+ Ios: &posture.MinVersionCheck{
+ MinVersion: "16.0",
+ },
+ Darwin: &posture.MinVersionCheck{
+ MinVersion: "22.0.0",
+ },
+ },
+ },
+ }
+
+ postureCheck2 := &posture.Checks{
+ ID: postureCheckID2,
+ AccountID: accountID,
+ Name: "Posture Check 2",
+ Description: "Geo location check",
+ Checks: posture.ChecksDefinition{
+ GeoLocationCheck: &posture.GeoLocationCheck{
+ Locations: []posture.Location{
+ {
+ CountryCode: "US",
+ CityName: "San Francisco",
+ },
+ {
+ CountryCode: "GB",
+ CityName: "London",
+ },
+ },
+ Action: "allow",
+ },
+ PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{
+ Ranges: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ netip.MustParsePrefix("10.0.0.0/8"),
+ },
+ Action: "allow",
+ },
+ },
+ }
+
+ account.PostureChecks = []*posture.Checks{postureCheck1, postureCheck2}
+
+ // Create Networks
+ network1 := &networkTypes.Network{
+ ID: networkID1,
+ AccountID: accountID,
+ Name: "Network 1",
+ Description: "Primary network",
+ }
+
+ account.Networks = []*networkTypes.Network{network1}
+
+ // Create Network Routers
+ router1 := &routerTypes.NetworkRouter{
+ ID: routerID1,
+ AccountID: accountID,
+ NetworkID: networkID1,
+ Peer: peerID1,
+ PeerGroups: []string{},
+ Masquerade: true,
+ Metric: 100,
+ }
+
+ account.NetworkRouters = []*routerTypes.NetworkRouter{router1}
+
+ // Create Network Resources
+ resource1 := &resourceTypes.NetworkResource{
+ ID: resourceID1,
+ AccountID: accountID,
+ NetworkID: networkID1,
+ Name: "Resource 1",
+ Description: "Web server",
+ Prefix: netip.MustParsePrefix("192.168.1.100/32"),
+ Type: resourceTypes.Host,
+ }
+
+ account.NetworkResources = []*resourceTypes.NetworkResource{resource1}
+
+ // Create Onboarding
+ account.Onboarding = types.AccountOnboarding{
+ AccountID: accountID,
+ OnboardingFlowPending: true,
+ SignupFormPending: false,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+
+ // Save the account to the database
+ err = store.SaveAccount(ctx, account)
+ require.NoError(t, err, "Failed to save comprehensive test account")
+
+ // Retrieve the account from the database
+ retrievedAccount, err := store.GetAccount(ctx, accountID)
+ require.NoError(t, err, "Failed to retrieve account")
+ require.NotNil(t, retrievedAccount, "Retrieved account should not be nil")
+
+ // ========== VALIDATE TOP-LEVEL FIELDS ==========
+ t.Run("TopLevelFields", func(t *testing.T) {
+ assert.Equal(t, accountID, retrievedAccount.Id, "Account ID mismatch")
+ assert.Equal(t, userID1, retrievedAccount.CreatedBy, "CreatedBy mismatch")
+ assert.WithinDuration(t, now, retrievedAccount.CreatedAt, time.Second, "CreatedAt mismatch")
+ assert.Equal(t, "example.com", retrievedAccount.Domain, "Domain mismatch")
+ assert.Equal(t, "business", retrievedAccount.DomainCategory, "DomainCategory mismatch")
+ assert.True(t, retrievedAccount.IsDomainPrimaryAccount, "IsDomainPrimaryAccount should be true")
+ })
+
+ // ========== VALIDATE EMBEDDED NETWORK ==========
+ t.Run("EmbeddedNetwork", func(t *testing.T) {
+ require.NotNil(t, retrievedAccount.Network, "Network should not be nil")
+ assert.Equal(t, "test-network", retrievedAccount.Network.Identifier, "Network Identifier mismatch")
+ assert.Equal(t, "test-dns", retrievedAccount.Network.Dns, "Network DNS mismatch")
+ assert.Equal(t, uint64(42), retrievedAccount.Network.Serial, "Network Serial mismatch")
+
+ expectedIP := net.ParseIP("100.64.0.0")
+ assert.True(t, retrievedAccount.Network.Net.IP.Equal(expectedIP), "Network IP mismatch")
+ expectedMask := net.CIDRMask(10, 32)
+ assert.Equal(t, expectedMask, retrievedAccount.Network.Net.Mask, "Network Mask mismatch")
+ })
+
+ // ========== VALIDATE DNS SETTINGS ==========
+ t.Run("DNSSettings", func(t *testing.T) {
+ assert.Len(t, retrievedAccount.DNSSettings.DisabledManagementGroups, 2, "DisabledManagementGroups length mismatch")
+ assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-1", "Missing dns-group-1")
+ assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-2", "Missing dns-group-2")
+ })
+
+ // ========== VALIDATE SETTINGS ==========
+ t.Run("Settings", func(t *testing.T) {
+ require.NotNil(t, retrievedAccount.Settings, "Settings should not be nil")
+ assert.True(t, retrievedAccount.Settings.PeerLoginExpirationEnabled, "PeerLoginExpirationEnabled mismatch")
+ assert.Equal(t, time.Hour*24*30, retrievedAccount.Settings.PeerLoginExpiration, "PeerLoginExpiration mismatch")
+ assert.True(t, retrievedAccount.Settings.GroupsPropagationEnabled, "GroupsPropagationEnabled mismatch")
+ assert.True(t, retrievedAccount.Settings.JWTGroupsEnabled, "JWTGroupsEnabled mismatch")
+ assert.Equal(t, "groups", retrievedAccount.Settings.JWTGroupsClaimName, "JWTGroupsClaimName mismatch")
+ assert.Len(t, retrievedAccount.Settings.JWTAllowGroups, 2, "JWTAllowGroups length mismatch")
+ assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-1")
+ assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-2")
+ assert.False(t, retrievedAccount.Settings.RegularUsersViewBlocked, "RegularUsersViewBlocked mismatch")
+
+ // Validate Extra Settings
+ require.NotNil(t, retrievedAccount.Settings.Extra, "Extra settings should not be nil")
+ assert.True(t, retrievedAccount.Settings.Extra.PeerApprovalEnabled, "PeerApprovalEnabled mismatch")
+ assert.Len(t, retrievedAccount.Settings.Extra.IntegratedValidatorGroups, 1, "IntegratedValidatorGroups length mismatch")
+ assert.Equal(t, "validator-1", retrievedAccount.Settings.Extra.IntegratedValidatorGroups[0])
+ })
+
+ // ========== VALIDATE SETUP KEYS ==========
+ t.Run("SetupKeys", func(t *testing.T) {
+ require.Len(t, retrievedAccount.SetupKeys, 2, "Should have 2 setup keys")
+
+ // Validate Setup Key 1
+ sk1, exists := retrievedAccount.SetupKeys["setup-key-secret-1"]
+ require.True(t, exists, "Setup key 1 should exist")
+ assert.Equal(t, "Setup Key 1", sk1.Name, "Setup key 1 name mismatch")
+ assert.Equal(t, types.SetupKeyReusable, sk1.Type, "Setup key 1 type mismatch")
+ assert.False(t, sk1.Revoked, "Setup key 1 should not be revoked")
+ assert.Equal(t, 5, sk1.UsedTimes, "Setup key 1 used times mismatch")
+ assert.Equal(t, 100, sk1.UsageLimit, "Setup key 1 usage limit mismatch")
+ assert.False(t, sk1.Ephemeral, "Setup key 1 should not be ephemeral")
+ assert.Len(t, sk1.AutoGroups, 2, "Setup key 1 auto groups length mismatch")
+ assert.Contains(t, sk1.AutoGroups, groupID1)
+ assert.Contains(t, sk1.AutoGroups, groupID2)
+
+ // Validate Setup Key 2
+ sk2, exists := retrievedAccount.SetupKeys["setup-key-secret-2"]
+ require.True(t, exists, "Setup key 2 should exist")
+ assert.Equal(t, "Setup Key 2 (One-off)", sk2.Name, "Setup key 2 name mismatch")
+ assert.Equal(t, types.SetupKeyOneOff, sk2.Type, "Setup key 2 type mismatch")
+ assert.True(t, sk2.Revoked, "Setup key 2 should be revoked")
+ assert.Equal(t, 1, sk2.UsedTimes, "Setup key 2 used times mismatch")
+ assert.Equal(t, 1, sk2.UsageLimit, "Setup key 2 usage limit mismatch")
+ assert.True(t, sk2.Ephemeral, "Setup key 2 should be ephemeral")
+ assert.Len(t, sk2.AutoGroups, 0, "Setup key 2 should have empty auto groups")
+ })
+
+ // ========== VALIDATE PEERS ==========
+ t.Run("Peers", func(t *testing.T) {
+ require.Len(t, retrievedAccount.Peers, 3, "Should have 3 peers")
+
+ // Validate Peer 1
+ p1, exists := retrievedAccount.Peers[peerID1]
+ require.True(t, exists, "Peer 1 should exist")
+ assert.Equal(t, "Peer 1", p1.Name, "Peer 1 name mismatch")
+ assert.Equal(t, "peer-key-1-AAAA", p1.Key, "Peer 1 key mismatch")
+ assert.True(t, p1.IP.Equal(net.ParseIP("100.64.0.1")), "Peer 1 IP mismatch")
+ assert.Equal(t, userID1, p1.UserID, "Peer 1 user ID mismatch")
+ assert.True(t, p1.SSHEnabled, "Peer 1 SSH should be enabled")
+ assert.Equal(t, "ssh-rsa AAAAB3NzaC1...", p1.SSHKey, "Peer 1 SSH key mismatch")
+ assert.True(t, p1.LoginExpirationEnabled, "Peer 1 login expiration should be enabled")
+ assert.False(t, p1.Ephemeral, "Peer 1 should not be ephemeral")
+ assert.Equal(t, "peer1", p1.DNSLabel, "Peer 1 DNS label mismatch")
+
+ // Validate Peer 1 Meta
+ assert.Equal(t, "peer1.example.com", p1.Meta.Hostname, "Peer 1 hostname mismatch")
+ assert.Equal(t, "linux", p1.Meta.GoOS, "Peer 1 OS mismatch")
+ assert.Equal(t, "5.15.0", p1.Meta.Kernel, "Peer 1 kernel mismatch")
+ assert.Equal(t, "x86_64", p1.Meta.Core, "Peer 1 core mismatch")
+ assert.Equal(t, "ubuntu", p1.Meta.Platform, "Peer 1 platform mismatch")
+ assert.Equal(t, "Ubuntu 22.04", p1.Meta.OS, "Peer 1 OS version mismatch")
+ assert.Equal(t, "0.24.0", p1.Meta.WtVersion, "Peer 1 wt version mismatch")
+ assert.Equal(t, "ABC123", p1.Meta.SystemSerialNumber, "Peer 1 serial number mismatch")
+ assert.Equal(t, "Server Model X", p1.Meta.SystemProductName, "Peer 1 product name mismatch")
+ assert.Equal(t, "Dell Inc.", p1.Meta.SystemManufacturer, "Peer 1 manufacturer mismatch")
+
+ // Validate Network Addresses
+ assert.Len(t, p1.Meta.NetworkAddresses, 2, "Peer 1 should have 2 network addresses")
+ assert.Equal(t, netip.MustParsePrefix("192.168.1.10/32"), p1.Meta.NetworkAddresses[0].NetIP, "Network address 1 IP mismatch")
+ assert.Equal(t, "00:11:22:33:44:55", p1.Meta.NetworkAddresses[0].Mac, "Network address 1 MAC mismatch")
+ assert.Equal(t, netip.MustParsePrefix("10.0.0.5/32"), p1.Meta.NetworkAddresses[1].NetIP, "Network address 2 IP mismatch")
+ assert.Equal(t, "00:11:22:33:44:66", p1.Meta.NetworkAddresses[1].Mac, "Network address 2 MAC mismatch")
+
+ // Validate Peer 1 Status
+ require.NotNil(t, p1.Status, "Peer 1 status should not be nil")
+ assert.True(t, p1.Status.Connected, "Peer 1 should be connected")
+ assert.False(t, p1.Status.LoginExpired, "Peer 1 login should not be expired")
+ assert.False(t, p1.Status.RequiresApproval, "Peer 1 should not require approval")
+
+ // Validate Peer 1 Location
+ assert.True(t, p1.Location.ConnectionIP.Equal(net.ParseIP("203.0.113.10")), "Peer 1 connection IP mismatch")
+ assert.Equal(t, "US", p1.Location.CountryCode, "Peer 1 country code mismatch")
+ assert.Equal(t, "San Francisco", p1.Location.CityName, "Peer 1 city name mismatch")
+ assert.Equal(t, uint(5391959), p1.Location.GeoNameID, "Peer 1 geo name ID mismatch")
+
+ // Validate Peer 2
+ p2, exists := retrievedAccount.Peers[peerID2]
+ require.True(t, exists, "Peer 2 should exist")
+ assert.Equal(t, "Peer 2", p2.Name, "Peer 2 name mismatch")
+ assert.Equal(t, "peer-key-2-BBBB", p2.Key, "Peer 2 key mismatch")
+ assert.False(t, p2.SSHEnabled, "Peer 2 SSH should be disabled")
+ assert.False(t, p2.LoginExpirationEnabled, "Peer 2 login expiration should be disabled")
+ assert.True(t, p2.InactivityExpirationEnabled, "Peer 2 inactivity expiration should be enabled")
+
+ // Validate Peer 2 Status
+ require.NotNil(t, p2.Status, "Peer 2 status should not be nil")
+ assert.False(t, p2.Status.Connected, "Peer 2 should not be connected")
+ assert.True(t, p2.Status.LoginExpired, "Peer 2 login should be expired")
+ assert.True(t, p2.Status.RequiresApproval, "Peer 2 should require approval")
+
+ // Validate Peer 3 (Ephemeral)
+ p3, exists := retrievedAccount.Peers[peerID3]
+ require.True(t, exists, "Peer 3 should exist")
+ assert.True(t, p3.Ephemeral, "Peer 3 should be ephemeral")
+ assert.Equal(t, "Peer 3 (Ephemeral)", p3.Name, "Peer 3 name mismatch")
+ })
+
+ // ========== VALIDATE USERS ==========
+ t.Run("Users", func(t *testing.T) {
+ require.Len(t, retrievedAccount.Users, 2, "Should have 2 users")
+
+ // Validate User 1
+ u1, exists := retrievedAccount.Users[userID1]
+ require.True(t, exists, "User 1 should exist")
+ assert.Equal(t, types.UserRoleOwner, u1.Role, "User 1 role mismatch")
+ assert.False(t, u1.IsServiceUser, "User 1 should not be a service user")
+ assert.True(t, u1.NonDeletable, "User 1 should be non-deletable")
+ assert.Equal(t, types.UserIssuedAPI, u1.Issued, "User 1 issued type mismatch")
+ assert.Len(t, u1.AutoGroups, 1, "User 1 auto groups length mismatch")
+ assert.Contains(t, u1.AutoGroups, groupID1, "User 1 should have group1")
+ assert.False(t, u1.Blocked, "User 1 should not be blocked")
+ require.NotNil(t, u1.LastLogin, "User 1 last login should not be nil")
+ assert.WithinDuration(t, lastLogin, *u1.LastLogin, time.Second, "User 1 last login mismatch")
+
+ // Validate User 1 Integration Reference
+ assert.Equal(t, 123, u1.IntegrationReference.ID, "User 1 integration ID mismatch")
+ assert.Equal(t, "azure_ad", u1.IntegrationReference.IntegrationType, "User 1 integration type mismatch")
+
+ // Validate User 1 PATs
+ require.Len(t, u1.PATs, 2, "User 1 should have 2 PATs")
+
+ pat1Retrieved, exists := u1.PATs[patID1]
+ require.True(t, exists, "PAT 1 should exist")
+ assert.Equal(t, "PAT 1", pat1Retrieved.Name, "PAT 1 name mismatch")
+ assert.Equal(t, "hashed-token-1", pat1Retrieved.HashedToken, "PAT 1 hashed token mismatch")
+ require.NotNil(t, pat1Retrieved.LastUsed, "PAT 1 last used should not be nil")
+ assert.WithinDuration(t, patLastUsed, *pat1Retrieved.LastUsed, time.Second, "PAT 1 last used mismatch")
+ assert.Equal(t, userID1, pat1Retrieved.CreatedBy, "PAT 1 created by mismatch")
+ assert.Empty(t, pat1Retrieved.UserID, "PAT 1 UserID should be cleared")
+
+ pat2Retrieved, exists := u1.PATs[patID2]
+ require.True(t, exists, "PAT 2 should exist")
+ assert.Equal(t, "PAT 2", pat2Retrieved.Name, "PAT 2 name mismatch")
+ assert.Nil(t, pat2Retrieved.LastUsed, "PAT 2 last used should be nil")
+
+ // Validate User 2
+ u2, exists := retrievedAccount.Users[userID2]
+ require.True(t, exists, "User 2 should exist")
+ assert.Equal(t, types.UserRoleAdmin, u2.Role, "User 2 role mismatch")
+ assert.True(t, u2.IsServiceUser, "User 2 should be a service user")
+ assert.False(t, u2.NonDeletable, "User 2 should be deletable")
+ assert.Equal(t, types.UserIssuedIntegration, u2.Issued, "User 2 issued type mismatch")
+ assert.Equal(t, "google_workspace", u2.IntegrationReference.IntegrationType, "User 2 integration type mismatch")
+
+ // Validate User 2 PATs
+ require.Len(t, u2.PATs, 1, "User 2 should have 1 PAT")
+ pat3Retrieved, exists := u2.PATs[patID3]
+ require.True(t, exists, "PAT 3 should exist")
+ assert.Equal(t, "PAT 3", pat3Retrieved.Name, "PAT 3 name mismatch")
+ })
+
+ // ========== VALIDATE GROUPS ==========
+ t.Run("Groups", func(t *testing.T) {
+ require.Len(t, retrievedAccount.Groups, 2, "Should have 2 groups")
+
+ // Validate Group 1
+ g1, exists := retrievedAccount.Groups[groupID1]
+ require.True(t, exists, "Group 1 should exist")
+ assert.Equal(t, "Group 1", g1.Name, "Group 1 name mismatch")
+ assert.Equal(t, types.GroupIssuedAPI, g1.Issued, "Group 1 issued type mismatch")
+ assert.Len(t, g1.Peers, 2, "Group 1 should have 2 peers")
+ assert.Contains(t, g1.Peers, peerID1, "Group 1 should contain peer 1")
+ assert.Contains(t, g1.Peers, peerID2, "Group 1 should contain peer 2")
+
+ // Validate Group 1 Resources
+ assert.Len(t, g1.Resources, 1, "Group 1 should have 1 resource")
+ assert.Equal(t, "resource-1", g1.Resources[0].ID, "Group 1 resource ID mismatch")
+ assert.Equal(t, types.ResourceTypeHost, g1.Resources[0].Type, "Group 1 resource type mismatch")
+
+ // Validate Group 2
+ g2, exists := retrievedAccount.Groups[groupID2]
+ require.True(t, exists, "Group 2 should exist")
+ assert.Equal(t, "Group 2", g2.Name, "Group 2 name mismatch")
+ assert.Equal(t, types.GroupIssuedIntegration, g2.Issued, "Group 2 issued type mismatch")
+ assert.Len(t, g2.Peers, 1, "Group 2 should have 1 peer")
+ assert.Contains(t, g2.Peers, peerID3, "Group 2 should contain peer 3")
+ assert.Len(t, g2.Resources, 0, "Group 2 should have 0 resources")
+
+ // Validate Group 2 Integration Reference
+ assert.Equal(t, 789, g2.IntegrationReference.ID, "Group 2 integration ID mismatch")
+ assert.Equal(t, "okta", g2.IntegrationReference.IntegrationType, "Group 2 integration type mismatch")
+ })
+
+ // ========== VALIDATE POLICIES ==========
+ t.Run("Policies", func(t *testing.T) {
+ require.Len(t, retrievedAccount.Policies, 2, "Should have 2 policies")
+
+ // Validate Policy 1
+ pol1 := retrievedAccount.Policies[0]
+ if pol1.ID != policyID1 {
+ pol1 = retrievedAccount.Policies[1]
+ }
+ assert.Equal(t, policyID1, pol1.ID, "Policy 1 ID mismatch")
+ assert.Equal(t, "Policy 1", pol1.Name, "Policy 1 name mismatch")
+ assert.Equal(t, "Main access policy", pol1.Description, "Policy 1 description mismatch")
+ assert.True(t, pol1.Enabled, "Policy 1 should be enabled")
+
+ // Validate Policy 1 Rules
+ require.Len(t, pol1.Rules, 2, "Policy 1 should have 2 rules")
+
+ rule1 := pol1.Rules[0]
+ assert.Equal(t, "Rule 1", rule1.Name, "Rule 1 name mismatch")
+ assert.Equal(t, "Allow access", rule1.Description, "Rule 1 description mismatch")
+ assert.True(t, rule1.Enabled, "Rule 1 should be enabled")
+ assert.Equal(t, types.PolicyTrafficActionAccept, rule1.Action, "Rule 1 action mismatch")
+ assert.True(t, rule1.Bidirectional, "Rule 1 should be bidirectional")
+ assert.Equal(t, types.PolicyRuleProtocolALL, rule1.Protocol, "Rule 1 protocol mismatch")
+ assert.Len(t, rule1.Sources, 1, "Rule 1 sources length mismatch")
+ assert.Contains(t, rule1.Sources, groupID1, "Rule 1 should have group1 as source")
+ assert.Len(t, rule1.Destinations, 1, "Rule 1 destinations length mismatch")
+ assert.Contains(t, rule1.Destinations, groupID2, "Rule 1 should have group2 as destination")
+
+ rule2 := pol1.Rules[1]
+ assert.Equal(t, "Rule 2", rule2.Name, "Rule 2 name mismatch")
+ assert.Equal(t, types.PolicyTrafficActionDrop, rule2.Action, "Rule 2 action mismatch")
+ assert.False(t, rule2.Bidirectional, "Rule 2 should not be bidirectional")
+ assert.Equal(t, types.PolicyRuleProtocolTCP, rule2.Protocol, "Rule 2 protocol mismatch")
+ assert.Len(t, rule2.Ports, 2, "Rule 2 ports length mismatch")
+ assert.Contains(t, rule2.Ports, "22", "Rule 2 should have port 22")
+ assert.Contains(t, rule2.Ports, "3389", "Rule 2 should have port 3389")
+ assert.Len(t, rule2.PortRanges, 1, "Rule 2 port ranges length mismatch")
+ assert.Equal(t, uint16(8000), rule2.PortRanges[0].Start, "Rule 2 port range start mismatch")
+ assert.Equal(t, uint16(8999), rule2.PortRanges[0].End, "Rule 2 port range end mismatch")
+
+ // Validate Policy 2
+ pol2 := retrievedAccount.Policies[1]
+ if pol2.ID != policyID2 {
+ pol2 = retrievedAccount.Policies[0]
+ }
+ assert.Equal(t, policyID2, pol2.ID, "Policy 2 ID mismatch")
+ assert.Equal(t, "Policy 2", pol2.Name, "Policy 2 name mismatch")
+ assert.False(t, pol2.Enabled, "Policy 2 should be disabled")
+ require.Len(t, pol2.Rules, 1, "Policy 2 should have 1 rule")
+
+ rule3 := pol2.Rules[0]
+ assert.Equal(t, "Rule 3", rule3.Name, "Rule 3 name mismatch")
+ assert.False(t, rule3.Enabled, "Rule 3 should be disabled")
+ assert.Equal(t, types.PolicyRuleProtocolUDP, rule3.Protocol, "Rule 3 protocol mismatch")
+ })
+
+ // ========== VALIDATE ROUTES ==========
+ t.Run("Routes", func(t *testing.T) {
+ require.Len(t, retrievedAccount.Routes, 2, "Should have 2 routes")
+
+ // Validate Route 1
+ r1, exists := retrievedAccount.Routes[routeID1]
+ require.True(t, exists, "Route 1 should exist")
+ assert.Equal(t, "Route 1", r1.Description, "Route 1 description mismatch")
+ assert.Equal(t, route.IPv4Network, r1.NetworkType, "Route 1 network type mismatch")
+ assert.Equal(t, peerID1, r1.Peer, "Route 1 peer mismatch")
+ assert.Empty(t, r1.PeerGroups, "Route 1 peer groups should be empty")
+ assert.Equal(t, route.NetID("net-id-1"), r1.NetID, "Route 1 net ID mismatch")
+ assert.True(t, r1.Masquerade, "Route 1 masquerade should be enabled")
+ assert.Equal(t, 9999, r1.Metric, "Route 1 metric mismatch")
+ assert.True(t, r1.Enabled, "Route 1 should be enabled")
+ assert.Len(t, r1.Groups, 1, "Route 1 groups length mismatch")
+ assert.Contains(t, r1.Groups, groupID1, "Route 1 should have group1")
+ assert.Len(t, r1.AccessControlGroups, 1, "Route 1 ACL groups length mismatch")
+ assert.Contains(t, r1.AccessControlGroups, groupID2, "Route 1 should have group2 in ACL")
+
+ // Validate Route 1 Network CIDR
+ assert.Equal(t, "10.0.0.0/24", r1.Network.String(), "Route 1 network CIDR mismatch")
+
+ // Validate Route 2
+ r2, exists := retrievedAccount.Routes[routeID2]
+ require.True(t, exists, "Route 2 should exist")
+ assert.Equal(t, "Route 2 (High Availability)", r2.Description, "Route 2 description mismatch")
+ assert.Empty(t, r2.Peer, "Route 2 peer should be empty")
+ assert.Len(t, r2.PeerGroups, 1, "Route 2 peer groups length mismatch")
+ assert.Contains(t, r2.PeerGroups, groupID2, "Route 2 should have group2 as peer group")
+ assert.False(t, r2.Masquerade, "Route 2 masquerade should be disabled")
+ assert.Equal(t, 100, r2.Metric, "Route 2 metric mismatch")
+ assert.Equal(t, "192.168.1.0/24", r2.Network.String(), "Route 2 network CIDR mismatch")
+ })
+
+ // ========== VALIDATE NAME SERVER GROUPS ==========
+ t.Run("NameServerGroups", func(t *testing.T) {
+ require.Len(t, retrievedAccount.NameServerGroups, 2, "Should have 2 nameserver groups")
+
+ // Validate NS Group 1
+ nsg1, exists := retrievedAccount.NameServerGroups[nsGroupID1]
+ require.True(t, exists, "NS Group 1 should exist")
+ assert.Equal(t, "NS Group 1", nsg1.Name, "NS Group 1 name mismatch")
+ assert.Equal(t, "Primary nameservers", nsg1.Description, "NS Group 1 description mismatch")
+ assert.True(t, nsg1.Enabled, "NS Group 1 should be enabled")
+ assert.True(t, nsg1.Primary, "NS Group 1 should be primary")
+ assert.True(t, nsg1.SearchDomainsEnabled, "NS Group 1 search domains should be enabled")
+ assert.Empty(t, nsg1.AccountID, "NS Group 1 AccountID should be cleared")
+
+ // Validate NS Group 1 NameServers
+ require.Len(t, nsg1.NameServers, 2, "NS Group 1 should have 2 nameservers")
+ assert.Equal(t, netip.MustParseAddr("8.8.8.8"), nsg1.NameServers[0].IP, "NS Group 1 nameserver 1 IP mismatch")
+ assert.Equal(t, nbdns.UDPNameServerType, nsg1.NameServers[0].NSType, "NS Group 1 nameserver 1 type mismatch")
+ assert.Equal(t, 53, nsg1.NameServers[0].Port, "NS Group 1 nameserver 1 port mismatch")
+ assert.Equal(t, netip.MustParseAddr("8.8.4.4"), nsg1.NameServers[1].IP, "NS Group 1 nameserver 2 IP mismatch")
+
+ // Validate NS Group 1 Groups and Domains
+ assert.Len(t, nsg1.Groups, 2, "NS Group 1 groups length mismatch")
+ assert.Contains(t, nsg1.Groups, groupID1, "NS Group 1 should have group1")
+ assert.Contains(t, nsg1.Groups, groupID2, "NS Group 1 should have group2")
+ assert.Len(t, nsg1.Domains, 2, "NS Group 1 domains length mismatch")
+ assert.Contains(t, nsg1.Domains, "example.com", "NS Group 1 should have example.com domain")
+ assert.Contains(t, nsg1.Domains, "test.com", "NS Group 1 should have test.com domain")
+
+ // Validate NS Group 2
+ nsg2, exists := retrievedAccount.NameServerGroups[nsGroupID2]
+ require.True(t, exists, "NS Group 2 should exist")
+ assert.Equal(t, "NS Group 2", nsg2.Name, "NS Group 2 name mismatch")
+ assert.False(t, nsg2.Enabled, "NS Group 2 should be disabled")
+ assert.False(t, nsg2.Primary, "NS Group 2 should not be primary")
+ assert.False(t, nsg2.SearchDomainsEnabled, "NS Group 2 search domains should be disabled")
+ assert.Len(t, nsg2.NameServers, 1, "NS Group 2 should have 1 nameserver")
+ assert.Len(t, nsg2.Groups, 0, "NS Group 2 should have empty groups")
+ assert.Len(t, nsg2.Domains, 0, "NS Group 2 should have empty domains")
+ })
+
+ // ========== VALIDATE POSTURE CHECKS ==========
+ t.Run("PostureChecks", func(t *testing.T) {
+ require.Len(t, retrievedAccount.PostureChecks, 2, "Should have 2 posture checks")
+
+ // Find posture checks by ID
+ var pc1, pc2 *posture.Checks
+ for _, pc := range retrievedAccount.PostureChecks {
+ if pc.ID == postureCheckID1 {
+ pc1 = pc
+ } else if pc.ID == postureCheckID2 {
+ pc2 = pc
+ }
+ }
+
+ // Validate Posture Check 1
+ require.NotNil(t, pc1, "Posture check 1 should exist")
+ assert.Equal(t, "Posture Check 1", pc1.Name, "Posture check 1 name mismatch")
+ assert.Equal(t, "OS version check", pc1.Description, "Posture check 1 description mismatch")
+
+ // Validate NB Version Check
+ require.NotNil(t, pc1.Checks.NBVersionCheck, "NB version check should not be nil")
+ assert.Equal(t, "0.24.0", pc1.Checks.NBVersionCheck.MinVersion, "NB version check min version mismatch")
+
+ // Validate OS Version Check
+ require.NotNil(t, pc1.Checks.OSVersionCheck, "OS version check should not be nil")
+ require.NotNil(t, pc1.Checks.OSVersionCheck.Ios, "iOS version check should not be nil")
+ assert.Equal(t, "16.0", pc1.Checks.OSVersionCheck.Ios.MinVersion, "iOS min version mismatch")
+ require.NotNil(t, pc1.Checks.OSVersionCheck.Darwin, "Darwin version check should not be nil")
+ assert.Equal(t, "22.0.0", pc1.Checks.OSVersionCheck.Darwin.MinVersion, "Darwin min version mismatch")
+
+ // Validate Posture Check 2
+ require.NotNil(t, pc2, "Posture check 2 should exist")
+ assert.Equal(t, "Posture Check 2", pc2.Name, "Posture check 2 name mismatch")
+
+ // Validate Geo Location Check
+ require.NotNil(t, pc2.Checks.GeoLocationCheck, "Geo location check should not be nil")
+ assert.Equal(t, "allow", pc2.Checks.GeoLocationCheck.Action, "Geo location action mismatch")
+ assert.Len(t, pc2.Checks.GeoLocationCheck.Locations, 2, "Geo location check should have 2 locations")
+ assert.Equal(t, "US", pc2.Checks.GeoLocationCheck.Locations[0].CountryCode, "Location 1 country code mismatch")
+ assert.Equal(t, "San Francisco", pc2.Checks.GeoLocationCheck.Locations[0].CityName, "Location 1 city name mismatch")
+ assert.Equal(t, "GB", pc2.Checks.GeoLocationCheck.Locations[1].CountryCode, "Location 2 country code mismatch")
+ assert.Equal(t, "London", pc2.Checks.GeoLocationCheck.Locations[1].CityName, "Location 2 city name mismatch")
+
+ // Validate Peer Network Range Check
+ require.NotNil(t, pc2.Checks.PeerNetworkRangeCheck, "Peer network range check should not be nil")
+ assert.Equal(t, "allow", pc2.Checks.PeerNetworkRangeCheck.Action, "Peer network range action mismatch")
+ assert.Len(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, 2, "Peer network range check should have 2 ranges")
+ assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("192.168.0.0/16"), "Should have 192.168.0.0/16 range")
+ assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("10.0.0.0/8"), "Should have 10.0.0.0/8 range")
+ })
+
+ // ========== VALIDATE NETWORKS ==========
+ t.Run("Networks", func(t *testing.T) {
+ require.Len(t, retrievedAccount.Networks, 1, "Should have 1 network")
+
+ net1 := retrievedAccount.Networks[0]
+ assert.Equal(t, networkID1, net1.ID, "Network 1 ID mismatch")
+ assert.Equal(t, "Network 1", net1.Name, "Network 1 name mismatch")
+ assert.Equal(t, "Primary network", net1.Description, "Network 1 description mismatch")
+ })
+
+ // ========== VALIDATE NETWORK ROUTERS ==========
+ t.Run("NetworkRouters", func(t *testing.T) {
+ require.Len(t, retrievedAccount.NetworkRouters, 1, "Should have 1 network router")
+
+ router := retrievedAccount.NetworkRouters[0]
+ assert.Equal(t, routerID1, router.ID, "Router 1 ID mismatch")
+ assert.Equal(t, networkID1, router.NetworkID, "Router 1 network ID mismatch")
+ assert.Equal(t, peerID1, router.Peer, "Router 1 peer mismatch")
+ assert.Empty(t, router.PeerGroups, "Router 1 peer groups should be empty")
+ assert.True(t, router.Masquerade, "Router 1 masquerade should be enabled")
+ assert.Equal(t, 100, router.Metric, "Router 1 metric mismatch")
+ })
+
+ // ========== VALIDATE NETWORK RESOURCES ==========
+ t.Run("NetworkResources", func(t *testing.T) {
+ require.Len(t, retrievedAccount.NetworkResources, 1, "Should have 1 network resource")
+
+ res := retrievedAccount.NetworkResources[0]
+ assert.Equal(t, resourceID1, res.ID, "Resource 1 ID mismatch")
+ assert.Equal(t, networkID1, res.NetworkID, "Resource 1 network ID mismatch")
+ assert.Equal(t, "Resource 1", res.Name, "Resource 1 name mismatch")
+ assert.Equal(t, "Web server", res.Description, "Resource 1 description mismatch")
+ assert.Equal(t, netip.MustParsePrefix("192.168.1.100/32"), res.Prefix, "Resource 1 prefix mismatch")
+ assert.Equal(t, resourceTypes.Host, res.Type, "Resource 1 type mismatch")
+ })
+
+ // ========== VALIDATE ONBOARDING ==========
+ t.Run("Onboarding", func(t *testing.T) {
+ assert.Equal(t, accountID, retrievedAccount.Onboarding.AccountID, "Onboarding account ID mismatch")
+ assert.True(t, retrievedAccount.Onboarding.OnboardingFlowPending, "Onboarding flow should be pending")
+ assert.False(t, retrievedAccount.Onboarding.SignupFormPending, "Signup form should not be pending")
+ assert.WithinDuration(t, now, retrievedAccount.Onboarding.CreatedAt, time.Second, "Onboarding created at mismatch")
+ })
+
+ t.Log("✅ All comprehensive account field validations passed!")
+}
diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go
index d40c4664c..2e2623910 100644
--- a/management/server/store/sql_store_test.go
+++ b/management/server/store/sql_store_test.go
@@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
})
}
}
+
+func TestSqlStore_ApproveAccountPeers(t *testing.T) {
+ runTestForAllEngines(t, "", func(t *testing.T, store Store) {
+ accountID := "test-account"
+ ctx := context.Background()
+
+ account := newAccountWithId(ctx, accountID, "testuser", "example.com")
+ err := store.SaveAccount(ctx, account)
+ require.NoError(t, err)
+
+ peers := []*nbpeer.Peer{
+ {
+ ID: "peer1",
+ AccountID: accountID,
+ DNSLabel: "peer1.netbird.cloud",
+ Key: "peer1-key",
+ IP: net.ParseIP("100.64.0.1"),
+ Status: &nbpeer.PeerStatus{
+ RequiresApproval: true,
+ LastSeen: time.Now().UTC(),
+ },
+ },
+ {
+ ID: "peer2",
+ AccountID: accountID,
+ DNSLabel: "peer2.netbird.cloud",
+ Key: "peer2-key",
+ IP: net.ParseIP("100.64.0.2"),
+ Status: &nbpeer.PeerStatus{
+ RequiresApproval: true,
+ LastSeen: time.Now().UTC(),
+ },
+ },
+ {
+ ID: "peer3",
+ AccountID: accountID,
+ DNSLabel: "peer3.netbird.cloud",
+ Key: "peer3-key",
+ IP: net.ParseIP("100.64.0.3"),
+ Status: &nbpeer.PeerStatus{
+ RequiresApproval: false,
+ LastSeen: time.Now().UTC(),
+ },
+ },
+ }
+
+ for _, peer := range peers {
+ err = store.AddPeerToAccount(ctx, peer)
+ require.NoError(t, err)
+ }
+
+ t.Run("approve all pending peers", func(t *testing.T) {
+ count, err := store.ApproveAccountPeers(ctx, accountID)
+ require.NoError(t, err)
+ assert.Equal(t, 2, count)
+
+ allPeers, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "")
+ require.NoError(t, err)
+
+ for _, peer := range allPeers {
+ assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval", peer.ID)
+ }
+ })
+
+ t.Run("no peers to approve", func(t *testing.T) {
+ count, err := store.ApproveAccountPeers(ctx, accountID)
+ require.NoError(t, err)
+ assert.Equal(t, 0, count)
+ })
+
+ t.Run("non-existent account", func(t *testing.T) {
+ count, err := store.ApproveAccountPeers(ctx, "non-existent")
+ require.NoError(t, err)
+ assert.Equal(t, 0, count)
+ })
+ })
+}
diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go
new file mode 100644
index 000000000..350a1da83
--- /dev/null
+++ b/management/server/store/sqlstore_bench_test.go
@@ -0,0 +1,951 @@
+package store
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/netip"
+ "sort"
+ "sync"
+ "testing"
+ "time"
+
+ "gorm.io/driver/postgres"
+ "gorm.io/gorm"
+ "gorm.io/gorm/clause"
+
+ "github.com/jackc/pgx/v5/pgxpool"
+ log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+ resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
+ routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
+ networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/posture"
+ "github.com/netbirdio/netbird/management/server/testutil"
+ "github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/route"
+ "github.com/netbirdio/netbird/shared/management/status"
+)
+
+func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) {
+ start := time.Now()
+ defer func() {
+ elapsed := time.Since(start)
+ if elapsed > 1*time.Second {
+ log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
+ }
+ }()
+
+ var account types.Account
+ result := s.db.Model(&account).
+ Omit("GroupsG").
+ Preload("UsersG.PATsG"). // have to be specified as this is nested reference
+ Preload(clause.Associations).
+ Take(&account, idQueryCondition, accountID)
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.NewAccountNotFoundError(accountID)
+ }
+ return nil, status.NewGetAccountFromStoreError(result.Error)
+ }
+
+ // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
+ for i, policy := range account.Policies {
+ var rules []*types.PolicyRule
+ err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
+ if err != nil {
+ return nil, status.Errorf(status.NotFound, "rule not found")
+ }
+ account.Policies[i].Rules = rules
+ }
+
+ account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
+ for _, key := range account.SetupKeysG {
+ account.SetupKeys[key.Key] = key.Copy()
+ }
+ account.SetupKeysG = nil
+
+ account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
+ for _, peer := range account.PeersG {
+ account.Peers[peer.ID] = peer.Copy()
+ }
+ account.PeersG = nil
+
+ account.Users = make(map[string]*types.User, len(account.UsersG))
+ for _, user := range account.UsersG {
+ user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
+ for _, pat := range user.PATsG {
+ user.PATs[pat.ID] = pat.Copy()
+ }
+ account.Users[user.Id] = user.Copy()
+ }
+ account.UsersG = nil
+
+ account.Groups = make(map[string]*types.Group, len(account.GroupsG))
+ for _, group := range account.GroupsG {
+ account.Groups[group.ID] = group.Copy()
+ }
+ account.GroupsG = nil
+
+ var groupPeers []types.GroupPeer
+ s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
+ Find(&groupPeers)
+ for _, groupPeer := range groupPeers {
+ if group, ok := account.Groups[groupPeer.GroupID]; ok {
+ group.Peers = append(group.Peers, groupPeer.PeerID)
+ } else {
+ log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
+ }
+ }
+
+ account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
+ for _, route := range account.RoutesG {
+ account.Routes[route.ID] = route.Copy()
+ }
+ account.RoutesG = nil
+
+ account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
+ for _, ns := range account.NameServerGroupsG {
+ account.NameServerGroups[ns.ID] = ns.Copy()
+ }
+ account.NameServerGroupsG = nil
+
+ return &account, nil
+}
+
+func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) {
+ start := time.Now()
+ defer func() {
+ elapsed := time.Since(start)
+ if elapsed > 1*time.Second {
+ log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
+ }
+ }()
+
+ var account types.Account
+ result := s.db.Model(&account).
+ Preload("UsersG.PATsG"). // have to be specified as this is nested reference
+ Preload("Policies.Rules").
+ Preload("SetupKeysG").
+ Preload("PeersG").
+ Preload("UsersG").
+ Preload("GroupsG.GroupPeers").
+ Preload("RoutesG").
+ Preload("NameServerGroupsG").
+ Preload("PostureChecks").
+ Preload("Networks").
+ Preload("NetworkRouters").
+ Preload("NetworkResources").
+ Preload("Onboarding").
+ Take(&account, idQueryCondition, accountID)
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.NewAccountNotFoundError(accountID)
+ }
+ return nil, status.NewGetAccountFromStoreError(result.Error)
+ }
+
+ account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
+ for _, key := range account.SetupKeysG {
+ if key.UpdatedAt.IsZero() {
+ key.UpdatedAt = key.CreatedAt
+ }
+ if key.AutoGroups == nil {
+ key.AutoGroups = []string{}
+ }
+ account.SetupKeys[key.Key] = &key
+ }
+ account.SetupKeysG = nil
+
+ account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
+ for _, peer := range account.PeersG {
+ account.Peers[peer.ID] = &peer
+ }
+ account.PeersG = nil
+ account.Users = make(map[string]*types.User, len(account.UsersG))
+ for _, user := range account.UsersG {
+ user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
+ for _, pat := range user.PATsG {
+ pat.UserID = ""
+ user.PATs[pat.ID] = &pat
+ }
+ if user.AutoGroups == nil {
+ user.AutoGroups = []string{}
+ }
+ account.Users[user.Id] = &user
+ user.PATsG = nil
+ }
+ account.UsersG = nil
+ account.Groups = make(map[string]*types.Group, len(account.GroupsG))
+ for _, group := range account.GroupsG {
+ group.Peers = make([]string, len(group.GroupPeers))
+ for i, gp := range group.GroupPeers {
+ group.Peers[i] = gp.PeerID
+ }
+ if group.Resources == nil {
+ group.Resources = []types.Resource{}
+ }
+ account.Groups[group.ID] = group
+ }
+ account.GroupsG = nil
+
+ account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
+ for _, route := range account.RoutesG {
+ account.Routes[route.ID] = &route
+ }
+ account.RoutesG = nil
+ account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
+ for _, ns := range account.NameServerGroupsG {
+ ns.AccountID = ""
+ if ns.NameServers == nil {
+ ns.NameServers = []nbdns.NameServer{}
+ }
+ if ns.Groups == nil {
+ ns.Groups = []string{}
+ }
+ if ns.Domains == nil {
+ ns.Domains = []string{}
+ }
+ account.NameServerGroups[ns.ID] = &ns
+ }
+ account.NameServerGroupsG = nil
+ return &account, nil
+}
+
+func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
+ config, err := pgxpool.ParseConfig(dsn)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse database config: %w", err)
+ }
+
+ config.MaxConns = 12
+ config.MinConns = 2
+ config.MaxConnLifetime = time.Hour
+ config.HealthCheckPeriod = time.Minute
+
+ pool, err := pgxpool.NewWithConfig(ctx, config)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create connection pool: %w", err)
+ }
+
+ if err := pool.Ping(ctx); err != nil {
+ pool.Close()
+ return nil, fmt.Errorf("unable to ping database: %w", err)
+ }
+ return pool, nil
+}
+
+func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
+ cleanup, dsn, err := testutil.CreatePostgresTestContainer()
+ if err != nil {
+ b.Fatalf("failed to create test container: %v", err)
+ }
+
+ db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
+ if err != nil {
+ b.Fatalf("failed to connect database: %v", err)
+ }
+
+ pool, err := connectDBforTest(context.Background(), dsn)
+ if err != nil {
+ b.Fatalf("failed to connect database: %v", err)
+ }
+
+ models := []interface{}{
+ &types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{},
+ &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
+ &types.Policy{}, &types.PolicyRule{}, &route.Route{},
+ &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
+ &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
+ &types.AccountOnboarding{},
+ }
+
+ for i := len(models) - 1; i >= 0; i-- {
+ err := db.Migrator().DropTable(models[i])
+ if err != nil {
+ b.Fatalf("failed to drop table: %v", err)
+ }
+ }
+
+ err = db.AutoMigrate(models...)
+ if err != nil {
+ b.Fatalf("failed to migrate database: %v", err)
+ }
+
+ store := &SqlStore{
+ db: db,
+ pool: pool,
+ }
+
+ const (
+ accountID = "benchmark-account-id"
+ numUsers = 20
+ numPatsPerUser = 3
+ numSetupKeys = 25
+ numPeers = 200
+ numGroups = 30
+ numPolicies = 50
+ numRulesPerPolicy = 10
+ numRoutes = 40
+ numNSGroups = 10
+ numPostureChecks = 15
+ numNetworks = 5
+ numNetworkRouters = 5
+ numNetworkResources = 10
+ )
+
+ _, ipNet, _ := net.ParseCIDR("100.64.0.0/10")
+ acc := types.Account{
+ Id: accountID,
+ CreatedBy: "benchmark-user",
+ CreatedAt: time.Now(),
+ Domain: "benchmark.com",
+ IsDomainPrimaryAccount: true,
+ Network: &types.Network{
+ Identifier: "benchmark-net",
+ Net: *ipNet,
+ Serial: 1,
+ },
+ DNSSettings: types.DNSSettings{
+ DisabledManagementGroups: []string{"group-disabled-1"},
+ },
+ Settings: &types.Settings{},
+ }
+ if err := db.Create(&acc).Error; err != nil {
+ b.Fatalf("create account: %v", err)
+ }
+
+ var setupKeys []types.SetupKey
+ for i := 0; i < numSetupKeys; i++ {
+ setupKeys = append(setupKeys, types.SetupKey{
+ Id: fmt.Sprintf("keyid-%d", i),
+ AccountID: accountID,
+ Key: fmt.Sprintf("key-%d", i),
+ Name: fmt.Sprintf("Benchmark Key %d", i),
+ ExpiresAt: &time.Time{},
+ })
+ }
+ if err := db.Create(&setupKeys).Error; err != nil {
+ b.Fatalf("create setup keys: %v", err)
+ }
+
+ var peers []nbpeer.Peer
+ for i := 0; i < numPeers; i++ {
+ peers = append(peers, nbpeer.Peer{
+ ID: fmt.Sprintf("peer-%d", i),
+ AccountID: accountID,
+ Key: fmt.Sprintf("peerkey-%d", i),
+ IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)),
+ Name: fmt.Sprintf("peer-name-%d", i),
+ Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()},
+ })
+ }
+ if err := db.Create(&peers).Error; err != nil {
+ b.Fatalf("create peers: %v", err)
+ }
+
+ for i := 0; i < numUsers; i++ {
+ userID := fmt.Sprintf("user-%d", i)
+ user := types.User{Id: userID, AccountID: accountID}
+ if err := db.Create(&user).Error; err != nil {
+ b.Fatalf("create user %s: %v", userID, err)
+ }
+
+ var pats []types.PersonalAccessToken
+ for j := 0; j < numPatsPerUser; j++ {
+ pats = append(pats, types.PersonalAccessToken{
+ ID: fmt.Sprintf("pat-%d-%d", i, j),
+ UserID: userID,
+ Name: fmt.Sprintf("PAT %d for User %d", j, i),
+ })
+ }
+ if err := db.Create(&pats).Error; err != nil {
+ b.Fatalf("create pats for user %s: %v", userID, err)
+ }
+ }
+
+ var groups []*types.Group
+ for i := 0; i < numGroups; i++ {
+ groups = append(groups, &types.Group{
+ ID: fmt.Sprintf("group-%d", i),
+ AccountID: accountID,
+ Name: fmt.Sprintf("Group %d", i),
+ })
+ }
+ if err := db.Create(&groups).Error; err != nil {
+ b.Fatalf("create groups: %v", err)
+ }
+
+ for i := 0; i < numPolicies; i++ {
+ policyID := fmt.Sprintf("policy-%d", i)
+ policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true}
+ if err := db.Create(&policy).Error; err != nil {
+ b.Fatalf("create policy %s: %v", policyID, err)
+ }
+
+ var rules []*types.PolicyRule
+ for j := 0; j < numRulesPerPolicy; j++ {
+ rules = append(rules, &types.PolicyRule{
+ ID: fmt.Sprintf("rule-%d-%d", i, j),
+ PolicyID: policyID,
+ Name: fmt.Sprintf("Rule %d for Policy %d", j, i),
+ Enabled: true,
+ Protocol: "all",
+ })
+ }
+ if err := db.Create(&rules).Error; err != nil {
+ b.Fatalf("create rules for policy %s: %v", policyID, err)
+ }
+ }
+
+ var routes []route.Route
+ for i := 0; i < numRoutes; i++ {
+ routes = append(routes, route.Route{
+ ID: route.ID(fmt.Sprintf("route-%d", i)),
+ AccountID: accountID,
+ Description: fmt.Sprintf("Route %d", i),
+ Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)),
+ Enabled: true,
+ })
+ }
+ if err := db.Create(&routes).Error; err != nil {
+ b.Fatalf("create routes: %v", err)
+ }
+
+ var nsGroups []nbdns.NameServerGroup
+ for i := 0; i < numNSGroups; i++ {
+ nsGroups = append(nsGroups, nbdns.NameServerGroup{
+ ID: fmt.Sprintf("nsg-%d", i),
+ AccountID: accountID,
+ Name: fmt.Sprintf("NS Group %d", i),
+ Description: "Benchmark NS Group",
+ Enabled: true,
+ })
+ }
+ if err := db.Create(&nsGroups).Error; err != nil {
+ b.Fatalf("create nsgroups: %v", err)
+ }
+
+ var postureChecks []*posture.Checks
+ for i := 0; i < numPostureChecks; i++ {
+ postureChecks = append(postureChecks, &posture.Checks{
+ ID: fmt.Sprintf("pc-%d", i),
+ AccountID: accountID,
+ Name: fmt.Sprintf("Posture Check %d", i),
+ })
+ }
+ if err := db.Create(&postureChecks).Error; err != nil {
+ b.Fatalf("create posture checks: %v", err)
+ }
+
+ var networks []*networkTypes.Network
+ for i := 0; i < numNetworks; i++ {
+ networks = append(networks, &networkTypes.Network{
+ ID: fmt.Sprintf("nettype-%d", i),
+ AccountID: accountID,
+ Name: fmt.Sprintf("Network Type %d", i),
+ })
+ }
+ if err := db.Create(&networks).Error; err != nil {
+ b.Fatalf("create networks: %v", err)
+ }
+
+ var networkRouters []*routerTypes.NetworkRouter
+ for i := 0; i < numNetworkRouters; i++ {
+ networkRouters = append(networkRouters, &routerTypes.NetworkRouter{
+ ID: fmt.Sprintf("router-%d", i),
+ AccountID: accountID,
+ NetworkID: networks[i%numNetworks].ID,
+ Peer: peers[i%numPeers].ID,
+ })
+ }
+ if err := db.Create(&networkRouters).Error; err != nil {
+ b.Fatalf("create network routers: %v", err)
+ }
+
+ var networkResources []*resourceTypes.NetworkResource
+ for i := 0; i < numNetworkResources; i++ {
+ networkResources = append(networkResources, &resourceTypes.NetworkResource{
+ ID: fmt.Sprintf("resource-%d", i),
+ AccountID: accountID,
+ NetworkID: networks[i%numNetworks].ID,
+ Name: fmt.Sprintf("Resource %d", i),
+ })
+ }
+ if err := db.Create(&networkResources).Error; err != nil {
+ b.Fatalf("create network resources: %v", err)
+ }
+
+ onboarding := types.AccountOnboarding{
+ AccountID: accountID,
+ OnboardingFlowPending: true,
+ }
+ if err := db.Create(&onboarding).Error; err != nil {
+ b.Fatalf("create onboarding: %v", err)
+ }
+
+ return store, cleanup, accountID
+}
+
+func BenchmarkGetAccount(b *testing.B) {
+ store, cleanup, accountID := setupBenchmarkDB(b)
+ defer cleanup()
+ ctx := context.Background()
+ b.ResetTimer()
+ b.ReportAllocs()
+ b.Run("old", func(b *testing.B) {
+ for range b.N {
+ _, err := store.GetAccountSlow(ctx, accountID)
+ if err != nil {
+ b.Fatalf("GetAccountSlow failed: %v", err)
+ }
+ }
+ })
+ b.Run("gorm opt", func(b *testing.B) {
+ for range b.N {
+ _, err := store.GetAccountGormOpt(ctx, accountID)
+ if err != nil {
+ b.Fatalf("GetAccountFast failed: %v", err)
+ }
+ }
+ })
+ b.Run("raw", func(b *testing.B) {
+ for range b.N {
+ _, err := store.GetAccount(ctx, accountID)
+ if err != nil {
+ b.Fatalf("GetAccountPureSQL failed: %v", err)
+ }
+ }
+ })
+ store.pool.Close()
+}
+
+func TestAccountEquivalence(t *testing.T) {
+ store, cleanup, accountID := setupBenchmarkDB(t)
+ defer cleanup()
+ ctx := context.Background()
+
+ type getAccountFunc func(context.Context, string) (*types.Account, error)
+
+ tests := []struct {
+ name string
+ expectedF getAccountFunc
+ actualF getAccountFunc
+ }{
+ {"old vs new", store.GetAccountSlow, store.GetAccountGormOpt},
+ {"old vs raw", store.GetAccountSlow, store.GetAccount},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ expected, errOld := tt.expectedF(ctx, accountID)
+ assert.NoError(t, errOld, "expected function should not return an error")
+ assert.NotNil(t, expected, "expected should not be nil")
+
+ actual, errNew := tt.actualF(ctx, accountID)
+ assert.NoError(t, errNew, "actual function should not return an error")
+ assert.NotNil(t, actual, "actual should not be nil")
+ testAccountEquivalence(t, expected, actual)
+ })
+ }
+
+ expected, errOld := store.GetAccountSlow(ctx, accountID)
+ assert.NoError(t, errOld, "GetAccountSlow should not return an error")
+ assert.NotNil(t, expected, "expected should not be nil")
+
+ actual, errNew := store.GetAccount(ctx, accountID)
+ assert.NoError(t, errNew, "GetAccount (new) should not return an error")
+ assert.NotNil(t, actual, "actual should not be nil")
+}
+
+func testAccountEquivalence(t *testing.T, expected, actual *types.Account) {
+ assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal")
+ assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal")
+ assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second")
+ assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal")
+ assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal")
+ assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal")
+ assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal")
+ assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal")
+ assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal")
+
+ assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements")
+ for key, oldVal := range expected.SetupKeys {
+ newVal, ok := actual.SetupKeys[key]
+ assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key)
+ assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key)
+ }
+
+ assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements")
+ for key, oldVal := range expected.Peers {
+ newVal, ok := actual.Peers[key]
+ assert.True(t, ok, "Peer with ID '%s' should exist in new account", key)
+ assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key)
+ }
+
+ assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements")
+ for key, oldUser := range expected.Users {
+ newUser, ok := actual.Users[key]
+ assert.True(t, ok, "User with ID '%s' should exist in new account", key)
+
+ assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key)
+ for patKey, oldPAT := range oldUser.PATs {
+ newPAT, patOk := newUser.PATs[patKey]
+ assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key)
+ assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key)
+ }
+
+ oldUser.PATs = nil
+ newUser.PATs = nil
+ assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key)
+ }
+
+ assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements")
+ for key, oldVal := range expected.Groups {
+ newVal, ok := actual.Groups[key]
+ assert.True(t, ok, "Group with ID '%s' should exist in new account", key)
+ sort.Strings(oldVal.Peers)
+ sort.Strings(newVal.Peers)
+ assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key)
+ }
+
+ assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements")
+ for key, oldVal := range expected.Routes {
+ newVal, ok := actual.Routes[key]
+ assert.True(t, ok, "Route with ID '%s' should exist in new account", key)
+ assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key)
+ }
+
+ assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements")
+ for key, oldVal := range expected.NameServerGroups {
+ newVal, ok := actual.NameServerGroups[key]
+ assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key)
+ assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key)
+ }
+
+ assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements")
+ sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID })
+ sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID })
+ for i := range expected.Policies {
+ sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID })
+ sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID })
+ assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID)
+ }
+
+ assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements")
+ sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID })
+ sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID })
+ for i := range expected.PostureChecks {
+ assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID)
+ }
+
+ assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements")
+ sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID })
+ sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID })
+ for i := range expected.Networks {
+ assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID)
+ }
+
+ assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements")
+ sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID })
+ sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID })
+ for i := range expected.NetworkRouters {
+ assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID)
+ }
+
+ assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements")
+ sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID })
+ sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID })
+ for i := range expected.NetworkResources {
+ assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID)
+ }
+}
+
+func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) {
+ account, err := s.getAccount(ctx, accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ var wg sync.WaitGroup
+ errChan := make(chan error, 12)
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ keys, err := s.getSetupKeys(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.SetupKeysG = keys
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ peers, err := s.getPeers(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.PeersG = peers
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ users, err := s.getUsers(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.UsersG = users
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ groups, err := s.getGroups(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.GroupsG = groups
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ policies, err := s.getPolicies(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.Policies = policies
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ routes, err := s.getRoutes(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.RoutesG = routes
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ nsgs, err := s.getNameServerGroups(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.NameServerGroupsG = nsgs
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ checks, err := s.getPostureChecks(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.PostureChecks = checks
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ networks, err := s.getNetworks(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.Networks = networks
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ routers, err := s.getNetworkRouters(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.NetworkRouters = routers
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ resources, err := s.getNetworkResources(ctx, accountID)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ account.NetworkResources = resources
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ err := s.getAccountOnboarding(ctx, accountID, account)
+ if err != nil {
+ errChan <- err
+ return
+ }
+ }()
+
+ wg.Wait()
+ close(errChan)
+ for e := range errChan {
+ if e != nil {
+ return nil, e
+ }
+ }
+
+ var userIDs []string
+ for _, u := range account.UsersG {
+ userIDs = append(userIDs, u.Id)
+ }
+ var policyIDs []string
+ for _, p := range account.Policies {
+ policyIDs = append(policyIDs, p.ID)
+ }
+ var groupIDs []string
+ for _, g := range account.GroupsG {
+ groupIDs = append(groupIDs, g.ID)
+ }
+
+ wg.Add(3)
+ errChan = make(chan error, 3)
+
+ var pats []types.PersonalAccessToken
+ go func() {
+ defer wg.Done()
+ var err error
+ pats, err = s.getPersonalAccessTokens(ctx, userIDs)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ var rules []*types.PolicyRule
+ go func() {
+ defer wg.Done()
+ var err error
+ rules, err = s.getPolicyRules(ctx, policyIDs)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ var groupPeers []types.GroupPeer
+ go func() {
+ defer wg.Done()
+ var err error
+ groupPeers, err = s.getGroupPeers(ctx, groupIDs)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ wg.Wait()
+ close(errChan)
+ for e := range errChan {
+ if e != nil {
+ return nil, e
+ }
+ }
+
+ patsByUserID := make(map[string][]*types.PersonalAccessToken)
+ for i := range pats {
+ pat := &pats[i]
+ patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat)
+ pat.UserID = ""
+ }
+
+ rulesByPolicyID := make(map[string][]*types.PolicyRule)
+ for _, rule := range rules {
+ rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
+ }
+
+ peersByGroupID := make(map[string][]string)
+ for _, gp := range groupPeers {
+ peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
+ }
+
+ account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
+ for i := range account.SetupKeysG {
+ key := &account.SetupKeysG[i]
+ account.SetupKeys[key.Key] = key
+ }
+
+ account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
+ for i := range account.PeersG {
+ peer := &account.PeersG[i]
+ account.Peers[peer.ID] = peer
+ }
+
+ account.Users = make(map[string]*types.User, len(account.UsersG))
+ for i := range account.UsersG {
+ user := &account.UsersG[i]
+ user.PATs = make(map[string]*types.PersonalAccessToken)
+ if userPats, ok := patsByUserID[user.Id]; ok {
+ for j := range userPats {
+ pat := userPats[j]
+ user.PATs[pat.ID] = pat
+ }
+ }
+ account.Users[user.Id] = user
+ }
+
+ for i := range account.Policies {
+ policy := account.Policies[i]
+ if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
+ policy.Rules = policyRules
+ }
+ }
+
+ account.Groups = make(map[string]*types.Group, len(account.GroupsG))
+ for i := range account.GroupsG {
+ group := account.GroupsG[i]
+ if peerIDs, ok := peersByGroupID[group.ID]; ok {
+ group.Peers = peerIDs
+ }
+ account.Groups[group.ID] = group
+ }
+
+ account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
+ for i := range account.RoutesG {
+ route := &account.RoutesG[i]
+ account.Routes[route.ID] = route
+ }
+
+ account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
+ for i := range account.NameServerGroupsG {
+ nsg := &account.NameServerGroupsG[i]
+ nsg.AccountID = ""
+ account.NameServerGroups[nsg.ID] = nsg
+ }
+
+ account.SetupKeysG = nil
+ account.PeersG = nil
+ account.UsersG = nil
+ account.GroupsG = nil
+ account.RoutesG = nil
+ account.NameServerGroupsG = nil
+
+ return account, nil
+}
diff --git a/management/server/store/store.go b/management/server/store/store.go
index 21b660d96..0ec7949f9 100644
--- a/management/server/store/store.go
+++ b/management/server/store/store.go
@@ -143,6 +143,7 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
+ ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
@@ -468,6 +469,9 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine)
closeConnection := func() {
cleanup()
store.Close(ctx)
+ if store.pool != nil {
+ store.pool.Close()
+ }
}
return store, closeConnection, nil
@@ -487,12 +491,18 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
- db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
+ db, err := openDBWithRetry(dsn, kind, 5)
if err != nil {
return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err)
}
dsn, cleanup, err := createRandomDB(dsn, db, kind)
+
+ sqlDB, _ := db.DB()
+ if sqlDB != nil {
+ sqlDB.Close()
+ }
+
if err != nil {
return nil, nil, err
}
@@ -519,12 +529,22 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
}
- db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
+ db, err := openDBWithRetry(dsn, kind, 5)
if err != nil {
return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err)
}
+ sqlDB, err := db.DB()
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to get underlying sql.DB: %v", err)
+ }
+ sqlDB.SetMaxOpenConns(1)
+ sqlDB.SetMaxIdleConns(1)
+
dsn, cleanup, err := createRandomDB(dsn, db, kind)
+
+ sqlDB.Close()
+
if err != nil {
return nil, nil, err
}
@@ -537,6 +557,31 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine
return store, cleanup, nil
}
+func openDBWithRetry(dsn string, engine types.Engine, maxRetries int) (*gorm.DB, error) {
+ var db *gorm.DB
+ var err error
+
+ for i := range maxRetries {
+ switch engine {
+ case types.PostgresStoreEngine:
+ db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
+ case types.MysqlStoreEngine:
+ db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
+ }
+
+ if err == nil {
+ return db, nil
+ }
+
+ if i < maxRetries-1 {
+ waitTime := time.Duration(100*(i+1)) * time.Millisecond
+ time.Sleep(waitTime)
+ }
+ }
+
+ return nil, err
+}
+
func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) {
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
@@ -544,21 +589,63 @@ func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(
return "", nil, fmt.Errorf("failed to create database: %v", err)
}
- var err error
+ originalDSN := dsn
+
cleanup := func() {
+ var dropDB *gorm.DB
+ var err error
+
switch engine {
case types.PostgresStoreEngine:
- err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
+ dropDB, err = gorm.Open(postgres.Open(originalDSN), &gorm.Config{
+ SkipDefaultTransaction: true,
+ PrepareStmt: false,
+ })
+ if err != nil {
+ log.Errorf("failed to connect for dropping database %s: %v", dbName, err)
+ return
+ }
+ defer func() {
+ if sqlDB, _ := dropDB.DB(); sqlDB != nil {
+ sqlDB.Close()
+ }
+ }()
+
+ if sqlDB, _ := dropDB.DB(); sqlDB != nil {
+ sqlDB.SetMaxOpenConns(1)
+ sqlDB.SetMaxIdleConns(0)
+ sqlDB.SetConnMaxLifetime(time.Second)
+ }
+
+ err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)).Error
+
case types.MysqlStoreEngine:
- // err = killMySQLConnections(dsn, dbName)
- err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
+ dropDB, err = gorm.Open(mysql.Open(originalDSN+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{
+ SkipDefaultTransaction: true,
+ PrepareStmt: false,
+ })
+ if err != nil {
+ log.Errorf("failed to connect for dropping database %s: %v", dbName, err)
+ return
+ }
+ defer func() {
+ if sqlDB, _ := dropDB.DB(); sqlDB != nil {
+ sqlDB.Close()
+ }
+ }()
+
+ if sqlDB, _ := dropDB.DB(); sqlDB != nil {
+ sqlDB.SetMaxOpenConns(1)
+ sqlDB.SetMaxIdleConns(0)
+ sqlDB.SetConnMaxLifetime(time.Second)
+ }
+
+ err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)).Error
}
+
if err != nil {
log.Errorf("failed to drop database %s: %v", dbName, err)
- panic(err)
}
- sqlDB, _ := db.DB()
- _ = sqlDB.Close()
}
return replaceDBName(dsn, dbName), cleanup, nil
diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go
index d4301802f..bd7fbc235 100644
--- a/management/server/telemetry/grpc_metrics.go
+++ b/management/server/telemetry/grpc_metrics.go
@@ -16,7 +16,6 @@ type GRPCMetrics struct {
meter metric.Meter
syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter
- syncRequestHighLatencyCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter
@@ -46,14 +45,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
- syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter",
- metric.WithUnit("1"),
- metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"),
- )
- if err != nil {
- return nil, err
- }
-
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -126,7 +117,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
meter: meter,
syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
- syncRequestHighLatencyCounter: syncRequestHighLatencyCounter,
loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
@@ -175,9 +165,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
// CountSyncRequestDuration counts the duration of the sync gRPC requests
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
- if duration > HighLatencyThreshold {
- grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
- }
}
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go
index ae27466d9..c50ed1e51 100644
--- a/management/server/telemetry/http_api_metrics.go
+++ b/management/server/telemetry/http_api_metrics.go
@@ -7,8 +7,8 @@ import (
"strings"
"time"
- "github.com/google/uuid"
"github.com/gorilla/mux"
+ "github.com/rs/xid"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
@@ -169,7 +169,7 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
//nolint
ctx := context.WithValue(r.Context(), hook.ExecutionContextKey, hook.HTTPSource)
- reqID := uuid.New().String()
+ reqID := xid.New().String()
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
@@ -185,6 +185,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
h.ServeHTTP(w, r.WithContext(ctx))
+ userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
+ if err == nil {
+ if userAuth.AccountId != "" {
+ //nolint
+ ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
+ }
+ if userAuth.UserId != "" {
+ //nolint
+ ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
+ }
+ }
+
if w.Status() > 399 {
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
} else {
diff --git a/management/server/types/account.go b/management/server/types/account.go
index f830023c7..c43e0bb57 100644
--- a/management/server/types/account.go
+++ b/management/server/types/account.go
@@ -8,6 +8,7 @@ import (
"slices"
"strconv"
"strings"
+ "sync"
"time"
"github.com/hashicorp/go-multierror"
@@ -15,6 +16,7 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
+ "github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -39,8 +41,22 @@ const (
// firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules.
firewallRuleMinPortRangesVer = "0.48.0"
+ // firewallRuleMinNativeSSHVer defines the minimum peer version that supports native SSH features in the firewall rules.
+ firewallRuleMinNativeSSHVer = "0.60.0"
+
+ // nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
+ nativeSSHPortString = "22022"
+ nativeSSHPortNumber = 22022
+ // defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
+ defaultSSHPortString = "22"
+ defaultSSHPortNumber = 22
)
+type supportedFeatures struct {
+ nativeSSH bool
+ portRanges bool
+}
+
type LookupMap map[string]struct{}
// AccountMeta is a struct that contains a stripped down version of the Account object.
@@ -87,6 +103,13 @@ type Account struct {
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
+
+ NetworkMapCache *NetworkMapBuilder `gorm:"-"`
+ nmapInitOnce *sync.Once `gorm:"-"`
+}
+
+func (a *Account) InitOnce() {
+ a.nmapInitOnce = &sync.Once{}
}
// this class is used by gorm only
@@ -255,9 +278,9 @@ func (a *Account) GetPeerNetworkMap(
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
+ groupIDToUserIDs map[string][]string,
) *NetworkMap {
start := time.Now()
-
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
@@ -271,7 +294,7 @@ func (a *Account) GetPeerNetworkMap(
}
}
- aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
+ aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
@@ -301,7 +324,7 @@ func (a *Account) GetPeerNetworkMap(
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
- records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
+ records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
@@ -319,6 +342,8 @@ func (a *Account) GetPeerNetworkMap(
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
+ AuthorizedUsers: authorizedUsers,
+ EnableSSH: enableSSH,
}
if metrics != nil {
@@ -890,6 +915,8 @@ func (a *Account) Copy() *Account {
NetworkRouters: networkRouters,
NetworkResources: networkResources,
Onboarding: a.Onboarding,
+ NetworkMapCache: a.NetworkMapCache,
+ nmapInitOnce: a.nmapInitOnce,
}
}
@@ -988,8 +1015,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
-func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
+func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
+ authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs
+ sshEnabled := false
for _, policy := range a.Policies {
if !policy.Enabled {
@@ -1032,10 +1061,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
+
+ if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
+ sshEnabled = true
+ switch {
+ case len(rule.AuthorizedGroups) > 0:
+ for groupID, localUsers := range rule.AuthorizedGroups {
+ userIDs, ok := groupIDToUserIDs[groupID]
+ if !ok {
+ log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID)
+ continue
+ }
+
+ if len(localUsers) == 0 {
+ localUsers = []string{auth.Wildcard}
+ }
+
+ for _, localUser := range localUsers {
+ if authorizedUsers[localUser] == nil {
+ authorizedUsers[localUser] = make(map[string]struct{})
+ }
+ for _, userID := range userIDs {
+ authorizedUsers[localUser][userID] = struct{}{}
+ }
+ }
+ }
+ case rule.AuthorizedUser != "":
+ if authorizedUsers[auth.Wildcard] == nil {
+ authorizedUsers[auth.Wildcard] = make(map[string]struct{})
+ }
+ authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
+ default:
+ authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
+ }
+ } else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
+ sshEnabled = true
+ authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
+ }
}
}
- return getAccumulatedResources()
+ peers, fwRules := getAccumulatedResources()
+ return peers, fwRules, authorizedUsers, sshEnabled
+}
+
+func (a *Account) getAllowedUserIDs() map[string]struct{} {
+ users := make(map[string]struct{})
+ for _, nbUser := range a.Users {
+ if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
+ users[nbUser.Id] = struct{}{}
+ }
+ }
+ return users
}
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
@@ -1049,14 +1126,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
rules := make([]*FirewallRule, 0)
peers := make([]*nbpeer.Peer, 0)
- all, err := a.GetGroupAll()
- if err != nil {
- log.WithContext(ctx).Errorf("failed to get group all: %v", err)
- all = &Group{}
- }
-
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
- isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers {
if peer == nil {
continue
@@ -1067,16 +1137,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
peersExists[peer.ID] = struct{}{}
}
+ protocol := rule.Protocol
+ if protocol == PolicyRuleProtocolNetbirdSSH {
+ protocol = PolicyRuleProtocolTCP
+ }
+
fr := FirewallRule{
PolicyID: rule.ID,
PeerIP: peer.IP.String(),
Direction: direction,
Action: string(rule.Action),
- Protocol: string(rule.Protocol),
- }
-
- if isAll {
- fr.PeerIP = "0.0.0.0"
+ Protocol: string(protocol),
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
@@ -1098,6 +1169,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
}
}
+func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
+ return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
+}
+
+func portRangeIncludesSSH(portRanges []RulePortRange) bool {
+ for _, pr := range portRanges {
+ if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
+ return true
+ }
+ }
+ return false
+}
+
+func portsIncludesSSH(ports []string) bool {
+ for _, port := range ports {
+ if port == defaultSSHPortString || port == nativeSSHPortString {
+ return true
+ }
+ }
+ return false
+}
+
// getAllPeersFromGroups for given peer ID and list of groups
//
// Returns a list of peers from specified groups that pass specified posture checks
@@ -1244,6 +1337,13 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID
}
}
}
+ if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
+ _, distPeer := distributionPeers[rule.SourceResource.ID]
+ _, valid := validatedPeersMap[rule.SourceResource.ID]
+ if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) {
+ distPeersWithPolicy[rule.SourceResource.ID] = struct{}{}
+ }
+ }
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
for pID := range distPeersWithPolicy {
@@ -1589,6 +1689,10 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st
sourcePeers[peer] = struct{}{}
}
}
+
+ if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
+ sourcePeers[rule.SourceResource.ID] = struct{}{}
+ }
}
}
@@ -1639,24 +1743,46 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
return nil
}
+func (a *Account) GetActiveGroupUsers() map[string][]string {
+ allGroupID := ""
+ group, err := a.GetGroupAll()
+ if err != nil {
+ log.Errorf("failed to get group all: %v", err)
+ } else {
+ allGroupID = group.ID
+ }
+ groups := make(map[string][]string, len(a.GroupsG))
+ for _, user := range a.Users {
+ if !user.IsBlocked() && !user.IsServiceUser {
+ for _, groupID := range user.AutoGroups {
+ groups[groupID] = append(groups[groupID], user.Id)
+ }
+ groups[allGroupID] = append(groups[allGroupID], user.Id)
+ }
+ }
+ return groups
+}
+
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
+ features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
+
var expanded []*FirewallRule
- if len(rule.Ports) > 0 {
- for _, port := range rule.Ports {
- fr := base
- fr.Port = port
- expanded = append(expanded, &fr)
- }
- return expanded
+ for _, port := range rule.Ports {
+ fr := base
+ fr.Port = port
+ expanded = append(expanded, &fr)
}
- supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
for _, portRange := range rule.PortRanges {
+ // prefer PolicyRule.Ports
+ if len(rule.Ports) > 0 {
+ break
+ }
fr := base
- if supportPortRanges {
+ if features.portRanges {
fr.PortRange = portRange
} else {
// Peer doesn't support port ranges, only allow single-port ranges
@@ -1668,21 +1794,67 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer
expanded = append(expanded, &fr)
}
+ if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
+ expanded = addNativeSSHRule(base, expanded)
+ }
+
return expanded
}
-// peerSupportsPortRanges checks if the peer version supports port ranges.
-func peerSupportsPortRanges(peerVer string) bool {
- if strings.Contains(peerVer, "dev") {
- return true
+// addNativeSSHRule adds a native SSH rule (port 22022) to the expanded rules if the base rule has port 22 configured.
+func addNativeSSHRule(base FirewallRule, expanded []*FirewallRule) []*FirewallRule {
+ shouldAdd := false
+ for _, fr := range expanded {
+ if isPortInRule(nativeSSHPortString, 22022, fr) {
+ return expanded
+ }
+ if isPortInRule(defaultSSHPortString, 22, fr) {
+ shouldAdd = true
+ }
+ }
+ if !shouldAdd {
+ return expanded
}
- meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
- return err == nil && meetMinVer
+ fr := base
+ fr.Port = nativeSSHPortString
+ return append(expanded, &fr)
+}
+
+func isPortInRule(portString string, portInt uint16, rule *FirewallRule) bool {
+ return rule.Port == portString || (rule.PortRange.Start <= portInt && portInt <= rule.PortRange.End)
+}
+
+// shouldCheckRulesForNativeSSH determines whether specific policy rules should be checked for native SSH support.
+// While users can add the nativeSSHPortString, we look for cases when they used port 22 and based on SSH enabled
+// in both management and client, we indicate to add the native port.
+func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *nbpeer.Peer) bool {
+ return supportsNative && peer.SSHEnabled && peer.Meta.Flags.ServerSSHAllowed && rule.Protocol == PolicyRuleProtocolTCP
+}
+
+// peerSupportedFirewallFeatures checks if the peer version supports port ranges.
+func peerSupportedFirewallFeatures(peerVer string) supportedFeatures {
+ if strings.Contains(peerVer, "dev") {
+ return supportedFeatures{true, true}
+ }
+
+ var features supportedFeatures
+
+ meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinNativeSSHVer, peerVer)
+ features.nativeSSH = err == nil && meetMinVer
+
+ if features.nativeSSH {
+ features.portRanges = true
+ } else {
+ meetMinVer, err = posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
+ features.portRanges = err == nil && meetMinVer
+ }
+
+ return features
}
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
-func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
+func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
peerIPs := make(map[string]struct{})
@@ -1693,6 +1865,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
peerIPs[peerToConnect.IP.String()] = struct{}{}
}
+ for _, expiredPeer := range expiredPeers {
+ peerIPs[expiredPeer.IP.String()] = struct{}{}
+ }
+
for _, record := range customZone.Records {
if _, exists := peerIPs[record.RData]; exists {
filteredRecords = append(filteredRecords, record)
diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go
index cd221b590..2c9f2428d 100644
--- a/management/server/types/account_test.go
+++ b/management/server/types/account_test.go
@@ -839,12 +839,466 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
}
+func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
+ tests := []struct {
+ name string
+ peer *nbpeer.Peer
+ rule *PolicyRule
+ base FirewallRule
+ expectedPorts []string
+ }{
+ {
+ name: "adds port 22022 when SSH enabled on modern peer with port 22",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.60.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ Ports: []string{"22"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"22", "22022"},
+ },
+ {
+ name: "adds port 22022 once when port 22 is duplicated within policy",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.60.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ Ports: []string{"22", "80", "22"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"22", "80", "22", "22022"},
+ },
+ {
+ name: "does not add 22022 for peer with old version",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.59.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ Ports: []string{"22"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"22"},
+ },
+ {
+ name: "does not add 22022 when SSHEnabled is false",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: false,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.60.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ Ports: []string{"22"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"22"},
+ },
+ {
+ name: "does not add 22022 when ServerSSHAllowed is false",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.60.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: false},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ Ports: []string{"22"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"22"},
+ },
+ {
+ name: "does not add 22022 for UDP protocol",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.60.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolUDP,
+ Ports: []string{"22"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "udp"},
+ expectedPorts: []string{"22"},
+ },
+ {
+ name: "does not add 22022 when port 22 not in rule",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.60.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ Ports: []string{"80", "443"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"80", "443"},
+ },
+ {
+ name: "does not duplicate 22022 when already present",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.60.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ Ports: []string{"22", "22022"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"22", "22022"},
+ },
+ {
+ name: "does not duplicate 22022 when already within a port range",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.60.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ PortRanges: []RulePortRange{{Start: 20, End: 32000}},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"20-32000"},
+ },
+ {
+ name: "adds 22022 when port 22 in port range",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.60.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ PortRanges: []RulePortRange{{Start: 20, End: 25}},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"20-25", "22022"},
+ },
+ {
+ name: "adds single 22022 once when port 22 in multiple port ranges",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.60.0",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ PortRanges: []RulePortRange{{Start: 20, End: 25}, {Start: 10, End: 100}},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"20-25", "10-100", "22022"},
+ },
+ {
+ name: "dev suffix version supports all features",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "0.50.0-dev",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ Ports: []string{"22"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"22", "22022"},
+ },
+ {
+ name: "dev suffix version supports all features",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "dev",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ Ports: []string{"22"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"22", "22022"},
+ },
+ {
+ name: "development suffix version supports all features",
+ peer: &nbpeer.Peer{
+ ID: "peer1",
+ SSHEnabled: true,
+ Meta: nbpeer.PeerSystemMeta{
+ WtVersion: "development",
+ Flags: nbpeer.Flags{ServerSSHAllowed: true},
+ },
+ },
+ rule: &PolicyRule{
+ Protocol: PolicyRuleProtocolTCP,
+ Ports: []string{"22"},
+ },
+ base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
+ expectedPorts: []string{"22", "22022"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := expandPortsAndRanges(tt.base, tt.rule, tt.peer)
+
+ var ports []string
+ for _, fr := range result {
+ if fr.Port != "" {
+ ports = append(ports, fr.Port)
+ } else if fr.PortRange.Start > 0 {
+ ports = append(ports, fmt.Sprintf("%d-%d", fr.PortRange.Start, fr.PortRange.End))
+ }
+ }
+
+ assert.Equal(t, tt.expectedPorts, ports, "expanded ports should match expected")
+ })
+ }
+}
+
+func Test_GetActiveGroupUsers(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ expected map[string][]string
+ }{
+ {
+ name: "all users are active",
+ account: &Account{
+ Users: map[string]*User{
+ "user1": {
+ Id: "user1",
+ AutoGroups: []string{"group1", "group2"},
+ Blocked: false,
+ },
+ "user2": {
+ Id: "user2",
+ AutoGroups: []string{"group2", "group3"},
+ Blocked: false,
+ },
+ "user3": {
+ Id: "user3",
+ AutoGroups: []string{"group1"},
+ Blocked: false,
+ },
+ },
+ },
+ expected: map[string][]string{
+ "group1": {"user1", "user3"},
+ "group2": {"user1", "user2"},
+ "group3": {"user2"},
+ "": {"user1", "user2", "user3"},
+ },
+ },
+ {
+ name: "some users are blocked",
+ account: &Account{
+ Users: map[string]*User{
+ "user1": {
+ Id: "user1",
+ AutoGroups: []string{"group1", "group2"},
+ Blocked: false,
+ },
+ "user2": {
+ Id: "user2",
+ AutoGroups: []string{"group2", "group3"},
+ Blocked: true,
+ },
+ "user3": {
+ Id: "user3",
+ AutoGroups: []string{"group1", "group3"},
+ Blocked: false,
+ },
+ },
+ },
+ expected: map[string][]string{
+ "group1": {"user1", "user3"},
+ "group2": {"user1"},
+ "group3": {"user3"},
+ "": {"user1", "user3"},
+ },
+ },
+ {
+ name: "all users are blocked",
+ account: &Account{
+ Users: map[string]*User{
+ "user1": {
+ Id: "user1",
+ AutoGroups: []string{"group1"},
+ Blocked: true,
+ },
+ "user2": {
+ Id: "user2",
+ AutoGroups: []string{"group2"},
+ Blocked: true,
+ },
+ },
+ },
+ expected: map[string][]string{},
+ },
+ {
+ name: "user with no auto groups",
+ account: &Account{
+ Users: map[string]*User{
+ "user1": {
+ Id: "user1",
+ AutoGroups: []string{},
+ Blocked: false,
+ },
+ "user2": {
+ Id: "user2",
+ AutoGroups: []string{"group1"},
+ Blocked: false,
+ },
+ },
+ },
+ expected: map[string][]string{
+ "group1": {"user2"},
+ "": {"user1", "user2"},
+ },
+ },
+ {
+ name: "empty account",
+ account: &Account{
+ Users: map[string]*User{},
+ },
+ expected: map[string][]string{},
+ },
+ {
+ name: "multiple users in same group",
+ account: &Account{
+ Users: map[string]*User{
+ "user1": {
+ Id: "user1",
+ AutoGroups: []string{"group1"},
+ Blocked: false,
+ },
+ "user2": {
+ Id: "user2",
+ AutoGroups: []string{"group1"},
+ Blocked: false,
+ },
+ "user3": {
+ Id: "user3",
+ AutoGroups: []string{"group1"},
+ Blocked: false,
+ },
+ },
+ },
+ expected: map[string][]string{
+ "group1": {"user1", "user2", "user3"},
+ "": {"user1", "user2", "user3"},
+ },
+ },
+ {
+ name: "user in multiple groups with blocked users",
+ account: &Account{
+ Users: map[string]*User{
+ "user1": {
+ Id: "user1",
+ AutoGroups: []string{"group1", "group2", "group3"},
+ Blocked: false,
+ },
+ "user2": {
+ Id: "user2",
+ AutoGroups: []string{"group1", "group2"},
+ Blocked: true,
+ },
+ "user3": {
+ Id: "user3",
+ AutoGroups: []string{"group3"},
+ Blocked: false,
+ },
+ },
+ },
+ expected: map[string][]string{
+ "group1": {"user1"},
+ "group2": {"user1"},
+ "group3": {"user1", "user3"},
+ "": {"user1", "user3"},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.account.GetActiveGroupUsers()
+
+ // Check that the number of groups matches
+ assert.Equal(t, len(tt.expected), len(result), "number of groups should match")
+
+ // Check each group's users
+ for groupID, expectedUsers := range tt.expected {
+ actualUsers, exists := result[groupID]
+ assert.True(t, exists, "group %s should exist in result", groupID)
+ assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID)
+ }
+
+ // Ensure no extra groups in result
+ for groupID := range result {
+ _, exists := tt.expected[groupID]
+ assert.True(t, exists, "unexpected group %s in result", groupID)
+ }
+ })
+ }
+}
+
func Test_FilterZoneRecordsForPeers(t *testing.T) {
tests := []struct {
name string
peer *nbpeer.Peer
customZone nbdns.CustomZone
peersToConnect []*nbpeer.Peer
+ expiredPeers []*nbpeer.Peer
expectedRecords []nbdns.SimpleRecord
}{
{
@@ -857,6 +1311,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
},
},
peersToConnect: []*nbpeer.Peer{},
+ expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
@@ -890,7 +1345,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
}
return peers
}(),
- peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
+ expiredPeers: []*nbpeer.Peer{},
+ peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: func() []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
@@ -924,7 +1380,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
},
- peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
+ expiredPeers: []*nbpeer.Peer{},
+ peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
@@ -934,11 +1391,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
+ {
+ name: "expired peers are included in DNS entries",
+ customZone: nbdns.CustomZone{
+ Domain: "netbird.cloud.",
+ Records: []nbdns.SimpleRecord{
+ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
+ {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
+ {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
+ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
+ },
+ },
+ peersToConnect: []*nbpeer.Peer{
+ {ID: "peer1", IP: net.ParseIP("10.0.0.1")},
+ },
+ expiredPeers: []*nbpeer.Peer{
+ {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")},
+ },
+ peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
+ expectedRecords: []nbdns.SimpleRecord{
+ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
+ {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
+ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
+ },
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
+ result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers)
assert.Equal(t, len(tt.expectedRecords), len(result))
assert.ElementsMatch(t, tt.expectedRecords, result)
})
diff --git a/management/server/types/holder.go b/management/server/types/holder.go
new file mode 100644
index 000000000..3996db2b6
--- /dev/null
+++ b/management/server/types/holder.go
@@ -0,0 +1,43 @@
+package types
+
+import (
+ "context"
+ "sync"
+)
+
+type Holder struct {
+ mu sync.RWMutex
+ accounts map[string]*Account
+}
+
+func NewHolder() *Holder {
+ return &Holder{
+ accounts: make(map[string]*Account),
+ }
+}
+
+func (h *Holder) GetAccount(id string) *Account {
+ h.mu.RLock()
+ defer h.mu.RUnlock()
+ return h.accounts[id]
+}
+
+func (h *Holder) AddAccount(account *Account) {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ h.accounts[account.Id] = account
+}
+
+func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ if acc, ok := h.accounts[id]; ok {
+ return acc, nil
+ }
+ account, err := accGetter(context.Background(), id)
+ if err != nil {
+ return nil, err
+ }
+ h.accounts[id] = account
+ return account, nil
+}
diff --git a/management/server/types/network.go b/management/server/types/network.go
index ffc019565..d3708d80a 100644
--- a/management/server/types/network.go
+++ b/management/server/types/network.go
@@ -38,6 +38,8 @@ type NetworkMap struct {
FirewallRules []*FirewallRule
RoutesFirewallRules []*RouteFirewallRule
ForwardingRules []*ForwardingRule
+ AuthorizedUsers map[string]map[string]struct{}
+ EnableSSH bool
}
func (nm *NetworkMap) Merge(other *NetworkMap) {
diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go
new file mode 100644
index 000000000..c1099726f
--- /dev/null
+++ b/management/server/types/networkmap.go
@@ -0,0 +1,58 @@
+package types
+
+import (
+ "context"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/telemetry"
+)
+
+func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
+ if a.NetworkMapCache != nil {
+ return
+ }
+ a.nmapInitOnce.Do(func() {
+ a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
+ })
+}
+
+func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
+ a.initNetworkMapBuilder(validatedPeers)
+}
+
+func (a *Account) GetPeerNetworkMapExp(
+ ctx context.Context,
+ peerID string,
+ peersCustomZone nbdns.CustomZone,
+ validatedPeers map[string]struct{},
+ metrics *telemetry.AccountManagerMetrics,
+) *NetworkMap {
+ a.initNetworkMapBuilder(validatedPeers)
+ return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
+}
+
+func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
+ if a.NetworkMapCache == nil {
+ return nil
+ }
+ return a.NetworkMapCache.OnPeerAddedIncremental(peerId)
+}
+
+func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
+ if a.NetworkMapCache == nil {
+ return nil
+ }
+ return a.NetworkMapCache.OnPeerDeleted(peerId)
+}
+
+func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
+ if a.NetworkMapCache == nil {
+ return
+ }
+ a.NetworkMapCache.UpdatePeer(peer)
+}
+
+func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
+ a.initNetworkMapBuilder(validatedPeers)
+}
diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go
new file mode 100644
index 000000000..913094e4c
--- /dev/null
+++ b/management/server/types/networkmap_golden_test.go
@@ -0,0 +1,1069 @@
+package types_test
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net"
+ "net/netip"
+ "os"
+ "path/filepath"
+ "slices"
+ "sort"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/netbirdio/netbird/dns"
+ resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
+ routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
+ networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/posture"
+ "github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/route"
+)
+
+// update flag is used to update the golden file.
+// example: go test ./... -v -update
+// var update = flag.Bool("update", false, "update golden files")
+
+const (
+ numPeers = 100
+ devGroupID = "group-dev"
+ opsGroupID = "group-ops"
+ allGroupID = "group-all"
+ routeID = route.ID("route-main")
+ routeHA1ID = route.ID("route-ha-1")
+ routeHA2ID = route.ID("route-ha-2")
+ policyIDDevOps = "policy-dev-ops"
+ policyIDAll = "policy-all"
+ policyIDPosture = "policy-posture"
+ policyIDDrop = "policy-drop"
+ postureCheckID = "posture-check-ver"
+ networkResourceID = "res-database"
+ networkID = "net-database"
+ networkRouterID = "router-database"
+ nameserverGroupID = "ns-group-main"
+ testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map.
+ expiredPeerID = "peer-98" // This peer will be online but with an expired session.
+ offlinePeerID = "peer-99" // This peer will be completely offline.
+ routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network.
+ testAccountID = "account-golden-test"
+)
+
+func TestGetPeerNetworkMap_Golden(t *testing.T) {
+ account := createTestAccountWithEntities()
+
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ if peerID == offlinePeerID {
+ continue
+ }
+ validatedPeersMap[peerID] = struct{}{}
+ }
+
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+
+ networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
+
+ normalizeAndSortNetworkMap(networkMap)
+
+ jsonData, err := json.MarshalIndent(networkMap, "", " ")
+ require.NoError(t, err, "error marshaling network map to JSON")
+
+ goldenFilePath := filepath.Join("testdata", "networkmap_golden.json")
+
+ t.Log("Update golden file...")
+ err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
+ require.NoError(t, err)
+ err = os.WriteFile(goldenFilePath, jsonData, 0644)
+ require.NoError(t, err)
+
+ expectedJSON, err := os.ReadFile(goldenFilePath)
+ require.NoError(t, err, "error reading golden file")
+
+ require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from OLD method does not match golden file")
+}
+
+func TestGetPeerNetworkMap_Golden_New(t *testing.T) {
+ account := createTestAccountWithEntities()
+
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+
+ if peerID == offlinePeerID {
+ continue
+ }
+ validatedPeersMap[peerID] = struct{}{}
+ }
+
+ builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
+ networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
+
+ normalizeAndSortNetworkMap(networkMap)
+
+ jsonData, err := json.MarshalIndent(networkMap, "", " ")
+ require.NoError(t, err, "error marshaling network map to JSON")
+
+ goldenFilePath := filepath.Join("testdata", "networkmap_golden_new.json")
+
+ t.Log("Update golden file...")
+ err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
+ require.NoError(t, err)
+ err = os.WriteFile(goldenFilePath, jsonData, 0644)
+ require.NoError(t, err)
+
+ expectedJSON, err := os.ReadFile(goldenFilePath)
+ require.NoError(t, err, "error reading golden file")
+
+ require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from NEW builder does not match golden file")
+}
+
+func BenchmarkGetPeerNetworkMap(b *testing.B) {
+ account := createTestAccountWithEntities()
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ var peerIDs []string
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ validatedPeersMap[peerID] = struct{}{}
+ peerIDs = append(peerIDs, peerID)
+ }
+
+ b.ResetTimer()
+ b.Run("old builder", func(b *testing.B) {
+ for range b.N {
+ for _, peerID := range peerIDs {
+ _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
+ }
+ }
+ })
+ b.ResetTimer()
+ b.Run("new builder", func(b *testing.B) {
+ for range b.N {
+ builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
+ for _, peerID := range peerIDs {
+ _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil)
+ }
+ }
+ })
+}
+
+func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
+ account := createTestAccountWithEntities()
+
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ if peerID == offlinePeerID {
+ continue
+ }
+ validatedPeersMap[peerID] = struct{}{}
+ }
+
+ newPeerID := "peer-new-101"
+ newPeerIP := net.IP{100, 64, 1, 1}
+ newPeer := &nbpeer.Peer{
+ ID: newPeerID,
+ IP: newPeerIP,
+ Key: fmt.Sprintf("key-%s", newPeerID),
+ DNSLabel: "peernew101",
+ Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
+ UserID: "user-admin",
+ Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
+ LastLogin: func() *time.Time { t := time.Now(); return &t }(),
+ }
+
+ account.Peers[newPeerID] = newPeer
+
+ if devGroup, exists := account.Groups[devGroupID]; exists {
+ devGroup.Peers = append(devGroup.Peers, newPeerID)
+ }
+
+ if allGroup, exists := account.Groups[allGroupID]; exists {
+ allGroup.Peers = append(allGroup.Peers, newPeerID)
+ }
+
+ validatedPeersMap[newPeerID] = struct{}{}
+
+ if account.Network != nil {
+ account.Network.Serial++
+ }
+
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+
+ networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
+
+ normalizeAndSortNetworkMap(networkMap)
+
+ jsonData, err := json.MarshalIndent(networkMap, "", " ")
+ require.NoError(t, err, "error marshaling network map to JSON")
+
+ goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json")
+
+ t.Log("Update golden file with new peer...")
+ err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
+ require.NoError(t, err)
+ err = os.WriteFile(goldenFilePath, jsonData, 0644)
+ require.NoError(t, err)
+
+ expectedJSON, err := os.ReadFile(goldenFilePath)
+ require.NoError(t, err, "error reading golden file")
+
+ require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new peer does not match golden file")
+}
+
+func TestGetPeerNetworkMap_Golden_New_WithOnPeerAdded(t *testing.T) {
+ account := createTestAccountWithEntities()
+
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ if peerID == offlinePeerID {
+ continue
+ }
+ validatedPeersMap[peerID] = struct{}{}
+ }
+
+ builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
+
+ newPeerID := "peer-new-101"
+ newPeerIP := net.IP{100, 64, 1, 1}
+ newPeer := &nbpeer.Peer{
+ ID: newPeerID,
+ IP: newPeerIP,
+ Key: fmt.Sprintf("key-%s", newPeerID),
+ DNSLabel: "peernew101",
+ Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
+ UserID: "user-admin",
+ Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
+ LastLogin: func() *time.Time { t := time.Now(); return &t }(),
+ }
+
+ account.Peers[newPeerID] = newPeer
+
+ if devGroup, exists := account.Groups[devGroupID]; exists {
+ devGroup.Peers = append(devGroup.Peers, newPeerID)
+ }
+
+ if allGroup, exists := account.Groups[allGroupID]; exists {
+ allGroup.Peers = append(allGroup.Peers, newPeerID)
+ }
+
+ validatedPeersMap[newPeerID] = struct{}{}
+
+ if account.Network != nil {
+ account.Network.Serial++
+ }
+
+ err := builder.OnPeerAddedIncremental(newPeerID)
+ require.NoError(t, err, "error adding peer to cache")
+
+ networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
+
+ normalizeAndSortNetworkMap(networkMap)
+
+ jsonData, err := json.MarshalIndent(networkMap, "", " ")
+ require.NoError(t, err, "error marshaling network map to JSON")
+
+ goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json")
+ t.Log("Update golden file with OnPeerAdded...")
+ err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
+ require.NoError(t, err)
+ err = os.WriteFile(goldenFilePath, jsonData, 0644)
+ require.NoError(t, err)
+
+ expectedJSON, err := os.ReadFile(goldenFilePath)
+ require.NoError(t, err, "error reading golden file")
+
+ require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded does not match golden file")
+}
+
+func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
+ account := createTestAccountWithEntities()
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ var peerIDs []string
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ validatedPeersMap[peerID] = struct{}{}
+ peerIDs = append(peerIDs, peerID)
+ }
+ builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
+ newPeerID := "peer-new-101"
+ newPeer := &nbpeer.Peer{
+ ID: newPeerID,
+ IP: net.IP{100, 64, 1, 1},
+ Key: fmt.Sprintf("key-%s", newPeerID),
+ DNSLabel: "peernew101",
+ Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
+ UserID: "user-admin",
+ Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
+ }
+
+ account.Peers[newPeerID] = newPeer
+ account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID)
+ account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID)
+ validatedPeersMap[newPeerID] = struct{}{}
+
+ b.ResetTimer()
+ b.Run("old builder after add", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ for _, testingPeerID := range peerIDs {
+ _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
+ }
+ }
+ })
+
+ b.ResetTimer()
+ b.Run("new builder after add", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ _ = builder.OnPeerAddedIncremental(newPeerID)
+ for _, testingPeerID := range peerIDs {
+ _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
+ }
+ }
+ })
+}
+
+func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
+ account := createTestAccountWithEntities()
+
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ if peerID == offlinePeerID {
+ continue
+ }
+ validatedPeersMap[peerID] = struct{}{}
+ }
+
+ newRouterID := "peer-new-router-102"
+ newRouterIP := net.IP{100, 64, 1, 2}
+ newRouter := &nbpeer.Peer{
+ ID: newRouterID,
+ IP: newRouterIP,
+ Key: fmt.Sprintf("key-%s", newRouterID),
+ DNSLabel: "newrouter102",
+ Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
+ UserID: "user-admin",
+ Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
+ LastLogin: func() *time.Time { t := time.Now(); return &t }(),
+ }
+
+ account.Peers[newRouterID] = newRouter
+
+ if opsGroup, exists := account.Groups[opsGroupID]; exists {
+ opsGroup.Peers = append(opsGroup.Peers, newRouterID)
+ }
+
+ if allGroup, exists := account.Groups[allGroupID]; exists {
+ allGroup.Peers = append(allGroup.Peers, newRouterID)
+ }
+
+ newRoute := &route.Route{
+ ID: route.ID("route-new-router"),
+ Network: netip.MustParsePrefix("172.16.0.0/24"),
+ Peer: newRouter.Key,
+ PeerID: newRouterID,
+ Description: "Route from new router",
+ Enabled: true,
+ PeerGroups: []string{opsGroupID},
+ Groups: []string{devGroupID, opsGroupID},
+ AccessControlGroups: []string{devGroupID},
+ AccountID: account.Id,
+ }
+ account.Routes[newRoute.ID] = newRoute
+
+ validatedPeersMap[newRouterID] = struct{}{}
+
+ if account.Network != nil {
+ account.Network.Serial++
+ }
+
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+
+ networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
+
+ normalizeAndSortNetworkMap(networkMap)
+
+ jsonData, err := json.MarshalIndent(networkMap, "", " ")
+ require.NoError(t, err, "error marshaling network map to JSON")
+
+ goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json")
+
+ t.Log("Update golden file with new router...")
+ err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
+ require.NoError(t, err)
+ err = os.WriteFile(goldenFilePath, jsonData, 0644)
+ require.NoError(t, err)
+
+ expectedJSON, err := os.ReadFile(goldenFilePath)
+ require.NoError(t, err, "error reading golden file")
+
+ require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new router does not match golden file")
+}
+
+func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter(t *testing.T) {
+ account := createTestAccountWithEntities()
+
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ if peerID == offlinePeerID {
+ continue
+ }
+ validatedPeersMap[peerID] = struct{}{}
+ }
+
+ builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
+
+ newRouterID := "peer-new-router-102"
+ newRouterIP := net.IP{100, 64, 1, 2}
+ newRouter := &nbpeer.Peer{
+ ID: newRouterID,
+ IP: newRouterIP,
+ Key: fmt.Sprintf("key-%s", newRouterID),
+ DNSLabel: "newrouter102",
+ Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
+ UserID: "user-admin",
+ Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
+ LastLogin: func() *time.Time { t := time.Now(); return &t }(),
+ }
+
+ account.Peers[newRouterID] = newRouter
+
+ if opsGroup, exists := account.Groups[opsGroupID]; exists {
+ opsGroup.Peers = append(opsGroup.Peers, newRouterID)
+ }
+ if allGroup, exists := account.Groups[allGroupID]; exists {
+ allGroup.Peers = append(allGroup.Peers, newRouterID)
+ }
+
+ newRoute := &route.Route{
+ ID: route.ID("route-new-router"),
+ Network: netip.MustParsePrefix("172.16.0.0/24"),
+ Peer: newRouter.Key,
+ PeerID: newRouterID,
+ Description: "Route from new router",
+ Enabled: true,
+ PeerGroups: []string{opsGroupID},
+ Groups: []string{devGroupID, opsGroupID},
+ AccessControlGroups: []string{devGroupID},
+ AccountID: account.Id,
+ }
+ account.Routes[newRoute.ID] = newRoute
+
+ validatedPeersMap[newRouterID] = struct{}{}
+
+ if account.Network != nil {
+ account.Network.Serial++
+ }
+
+ err := builder.OnPeerAddedIncremental(newRouterID)
+ require.NoError(t, err, "error adding router to cache")
+
+ networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
+
+ normalizeAndSortNetworkMap(networkMap)
+
+ jsonData, err := json.MarshalIndent(networkMap, "", " ")
+ require.NoError(t, err, "error marshaling network map to JSON")
+
+ goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
+
+ t.Log("Update golden file with OnPeerAdded router...")
+ err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
+ require.NoError(t, err)
+ err = os.WriteFile(goldenFilePath, jsonData, 0644)
+ require.NoError(t, err)
+
+ expectedJSON, err := os.ReadFile(goldenFilePath)
+ require.NoError(t, err, "error reading golden file")
+
+ require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file")
+}
+
+func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
+ account := createTestAccountWithEntities()
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ var peerIDs []string
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ validatedPeersMap[peerID] = struct{}{}
+ peerIDs = append(peerIDs, peerID)
+ }
+ builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
+ newRouterID := "peer-new-router-102"
+ newRouterIP := net.IP{100, 64, 1, 2}
+ newRouter := &nbpeer.Peer{
+ ID: newRouterID,
+ IP: newRouterIP,
+ Key: fmt.Sprintf("key-%s", newRouterID),
+ DNSLabel: "newrouter102",
+ Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
+ UserID: "user-admin",
+ Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
+ LastLogin: func() *time.Time { t := time.Now(); return &t }(),
+ }
+
+ account.Peers[newRouterID] = newRouter
+
+ if opsGroup, exists := account.Groups[opsGroupID]; exists {
+ opsGroup.Peers = append(opsGroup.Peers, newRouterID)
+ }
+ if allGroup, exists := account.Groups[allGroupID]; exists {
+ allGroup.Peers = append(allGroup.Peers, newRouterID)
+ }
+
+ newRoute := &route.Route{
+ ID: route.ID("route-new-router"),
+ Network: netip.MustParsePrefix("172.16.0.0/24"),
+ Peer: newRouter.Key,
+ PeerID: newRouterID,
+ Description: "Route from new router",
+ Enabled: true,
+ PeerGroups: []string{opsGroupID},
+ Groups: []string{devGroupID, opsGroupID},
+ AccessControlGroups: []string{devGroupID},
+ AccountID: account.Id,
+ }
+ account.Routes[newRoute.ID] = newRoute
+
+ validatedPeersMap[newRouterID] = struct{}{}
+
+ b.ResetTimer()
+ b.Run("old builder after add", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ for _, testingPeerID := range peerIDs {
+ _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
+ }
+ }
+ })
+
+ b.ResetTimer()
+ b.Run("new builder after add", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ _ = builder.OnPeerAddedIncremental(newRouterID)
+ for _, testingPeerID := range peerIDs {
+ _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
+ }
+ }
+ })
+}
+
+func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
+ account := createTestAccountWithEntities()
+
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ if peerID == offlinePeerID {
+ continue
+ }
+ validatedPeersMap[peerID] = struct{}{}
+ }
+
+ deletedPeerID := "peer-25" // peer from devs group
+
+ delete(account.Peers, deletedPeerID)
+
+ if devGroup, exists := account.Groups[devGroupID]; exists {
+ devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool {
+ return id == deletedPeerID
+ })
+ }
+
+ if allGroup, exists := account.Groups[allGroupID]; exists {
+ allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool {
+ return id == deletedPeerID
+ })
+ }
+
+ delete(validatedPeersMap, deletedPeerID)
+
+ if account.Network != nil {
+ account.Network.Serial++
+ }
+
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+
+ networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
+
+ normalizeAndSortNetworkMap(networkMap)
+
+ jsonData, err := json.MarshalIndent(networkMap, "", " ")
+ require.NoError(t, err, "error marshaling network map to JSON")
+
+ goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json")
+
+ t.Log("Update golden file with deleted peer...")
+ err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
+ require.NoError(t, err)
+ err = os.WriteFile(goldenFilePath, jsonData, 0644)
+ require.NoError(t, err)
+
+ expectedJSON, err := os.ReadFile(goldenFilePath)
+ require.NoError(t, err, "error reading golden file")
+
+ require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file")
+}
+
+func TestGetPeerNetworkMap_Golden_New_WithOnPeerDeleted(t *testing.T) {
+ account := createTestAccountWithEntities()
+
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ if peerID == offlinePeerID {
+ continue
+ }
+ validatedPeersMap[peerID] = struct{}{}
+ }
+
+ builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
+
+ deletedPeerID := "peer-25" // devs group peer
+
+ delete(account.Peers, deletedPeerID)
+
+ if devGroup, exists := account.Groups[devGroupID]; exists {
+ devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool {
+ return id == deletedPeerID
+ })
+ }
+
+ if allGroup, exists := account.Groups[allGroupID]; exists {
+ allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool {
+ return id == deletedPeerID
+ })
+ }
+
+ delete(validatedPeersMap, deletedPeerID)
+
+ if account.Network != nil {
+ account.Network.Serial++
+ }
+
+ err := builder.OnPeerDeleted(deletedPeerID)
+ require.NoError(t, err, "error deleting peer from cache")
+
+ networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
+
+ normalizeAndSortNetworkMap(networkMap)
+
+ jsonData, err := json.MarshalIndent(networkMap, "", " ")
+ require.NoError(t, err, "error marshaling network map to JSON")
+
+ goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json")
+ t.Log("Update golden file with OnPeerDeleted...")
+ err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
+ require.NoError(t, err)
+ err = os.WriteFile(goldenFilePath, jsonData, 0644)
+ require.NoError(t, err)
+
+ expectedJSON, err := os.ReadFile(goldenFilePath)
+ require.NoError(t, err, "error reading golden file")
+
+ require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerDeleted does not match golden file")
+}
+
+func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
+ account := createTestAccountWithEntities()
+
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ if peerID == offlinePeerID {
+ continue
+ }
+ validatedPeersMap[peerID] = struct{}{}
+ }
+
+ deletedRouterID := "peer-75" // router peer
+
+ var affectedRoute *route.Route
+ for _, r := range account.Routes {
+ if r.PeerID == deletedRouterID {
+ affectedRoute = r
+ break
+ }
+ }
+ require.NotNil(t, affectedRoute, "Router peer should have a route")
+
+ for _, group := range account.Groups {
+ group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool {
+ return id == deletedRouterID
+ })
+ }
+
+ for routeID, r := range account.Routes {
+ if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID {
+ delete(account.Routes, routeID)
+ }
+ }
+ delete(account.Peers, deletedRouterID)
+ delete(validatedPeersMap, deletedRouterID)
+
+ if account.Network != nil {
+ account.Network.Serial++
+ }
+
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+
+ networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
+
+ normalizeAndSortNetworkMap(networkMap)
+
+ jsonData, err := json.MarshalIndent(networkMap, "", " ")
+ require.NoError(t, err, "error marshaling network map to JSON")
+
+ goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json")
+
+ t.Log("Update golden file with deleted peer...")
+ err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
+ require.NoError(t, err)
+ err = os.WriteFile(goldenFilePath, jsonData, 0644)
+ require.NoError(t, err)
+
+ expectedJSON, err := os.ReadFile(goldenFilePath)
+ require.NoError(t, err, "error reading golden file")
+
+ require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file")
+}
+
+func TestGetPeerNetworkMap_Golden_New_WithDeletedRouterPeer(t *testing.T) {
+ account := createTestAccountWithEntities()
+
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ if peerID == offlinePeerID {
+ continue
+ }
+ validatedPeersMap[peerID] = struct{}{}
+ }
+
+ builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
+
+ deletedRouterID := "peer-75" // router peer
+
+ var affectedRoute *route.Route
+ for _, r := range account.Routes {
+ if r.PeerID == deletedRouterID {
+ affectedRoute = r
+ break
+ }
+ }
+ require.NotNil(t, affectedRoute, "Router peer should have a route")
+
+ for _, group := range account.Groups {
+ group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool {
+ return id == deletedRouterID
+ })
+ }
+ for routeID, r := range account.Routes {
+ if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID {
+ delete(account.Routes, routeID)
+ }
+ }
+ delete(account.Peers, deletedRouterID)
+ delete(validatedPeersMap, deletedRouterID)
+
+ if account.Network != nil {
+ account.Network.Serial++
+ }
+
+ err := builder.OnPeerDeleted(deletedRouterID)
+ require.NoError(t, err, "error deleting routing peer from cache")
+
+ networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
+
+ normalizeAndSortNetworkMap(networkMap)
+
+ jsonData, err := json.MarshalIndent(networkMap, "", " ")
+ require.NoError(t, err)
+
+ goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json")
+
+ t.Log("Update golden file with deleted router...")
+ err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
+ require.NoError(t, err)
+ err = os.WriteFile(goldenFilePath, jsonData, 0644)
+ require.NoError(t, err)
+
+ expectedJSON, err := os.ReadFile(goldenFilePath)
+ require.NoError(t, err)
+
+ require.JSONEq(t, string(expectedJSON), string(jsonData),
+ "network map after deleting router does not match golden file")
+}
+
+func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
+ account := createTestAccountWithEntities()
+ ctx := context.Background()
+ validatedPeersMap := make(map[string]struct{})
+ var peerIDs []string
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ validatedPeersMap[peerID] = struct{}{}
+ peerIDs = append(peerIDs, peerID)
+ }
+
+ deletedPeerID := "peer-25"
+
+ delete(account.Peers, deletedPeerID)
+ account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool {
+ return id == deletedPeerID
+ })
+ account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool {
+ return id == deletedPeerID
+ })
+ delete(validatedPeersMap, deletedPeerID)
+
+ builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
+
+ b.ResetTimer()
+ b.Run("old builder after delete", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ for _, testingPeerID := range peerIDs {
+ _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
+ }
+ }
+ })
+
+ b.ResetTimer()
+ b.Run("new builder after delete", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ _ = builder.OnPeerDeleted(deletedPeerID)
+ for _, testingPeerID := range peerIDs {
+ _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
+ }
+ }
+ })
+}
+
+func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) {
+ for _, peer := range networkMap.Peers {
+ if peer.Status != nil {
+ peer.Status.LastSeen = time.Time{}
+ }
+ peer.LastLogin = &time.Time{}
+ }
+ for _, peer := range networkMap.OfflinePeers {
+ if peer.Status != nil {
+ peer.Status.LastSeen = time.Time{}
+ }
+ peer.LastLogin = &time.Time{}
+ }
+
+ sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID })
+ sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID })
+ sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID })
+
+ sort.Slice(networkMap.FirewallRules, func(i, j int) bool {
+ r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j]
+ if r1.PeerIP != r2.PeerIP {
+ return r1.PeerIP < r2.PeerIP
+ }
+ if r1.Protocol != r2.Protocol {
+ return r1.Protocol < r2.Protocol
+ }
+ if r1.Direction != r2.Direction {
+ return r1.Direction < r2.Direction
+ }
+ if r1.Action != r2.Action {
+ return r1.Action < r2.Action
+ }
+ return r1.Port < r2.Port
+ })
+
+ sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool {
+ r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j]
+ if r1.RouteID != r2.RouteID {
+ return r1.RouteID < r2.RouteID
+ }
+ if r1.Action != r2.Action {
+ return r1.Action < r2.Action
+ }
+ if r1.Destination != r2.Destination {
+ return r1.Destination < r2.Destination
+ }
+ if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 {
+ if r1.SourceRanges[0] != r2.SourceRanges[0] {
+ return r1.SourceRanges[0] < r2.SourceRanges[0]
+ }
+ }
+ return r1.Port < r2.Port
+ })
+
+ for _, ranges := range networkMap.RoutesFirewallRules {
+ sort.Slice(ranges.SourceRanges, func(i, j int) bool {
+ return ranges.SourceRanges[i] < ranges.SourceRanges[j]
+ })
+ }
+}
+
+func createTestAccountWithEntities() *types.Account {
+ peers := make(map[string]*nbpeer.Peer)
+ devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
+
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ ip := net.IP{100, 64, 0, byte(i + 1)}
+ wtVersion := "0.25.0"
+ if i%2 == 0 {
+ wtVersion = "0.40.0"
+ }
+
+ p := &nbpeer.Peer{
+ ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
+ Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
+ UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
+ }
+
+ if peerID == expiredPeerID {
+ p.LoginExpirationEnabled = true
+ pastTimestamp := time.Now().Add(-2 * time.Hour)
+ p.LastLogin = &pastTimestamp
+ }
+
+ peers[peerID] = p
+ allGroupPeers = append(allGroupPeers, peerID)
+ if i < numPeers/2 {
+ devGroupPeers = append(devGroupPeers, peerID)
+ } else {
+ opsGroupPeers = append(opsGroupPeers, peerID)
+ }
+
+ }
+
+ groups := map[string]*types.Group{
+ allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
+ devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
+ opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
+ }
+
+ policies := []*types.Policy{
+ {
+ ID: policyIDAll, Name: "Default-Allow", Enabled: true,
+ Rules: []*types.PolicyRule{{
+ ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept,
+ Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
+ Sources: []string{allGroupID}, Destinations: []string{allGroupID},
+ }},
+ },
+ {
+ ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
+ Rules: []*types.PolicyRule{{
+ ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept,
+ Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false,
+ PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}},
+ Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
+ }},
+ },
+ {
+ ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
+ Rules: []*types.PolicyRule{{
+ ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop,
+ Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
+ Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
+ }},
+ },
+ {
+ ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
+ SourcePostureChecks: []string{postureCheckID},
+ Rules: []*types.PolicyRule{{
+ ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept,
+ Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
+ Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID},
+ }},
+ },
+ }
+
+ routes := map[route.ID]*route.Route{
+ routeID: {
+ ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
+ Peer: peers["peer-75"].Key,
+ PeerID: "peer-75",
+ Description: "Route to internal resource", Enabled: true,
+ PeerGroups: []string{devGroupID, opsGroupID},
+ Groups: []string{devGroupID, opsGroupID},
+ AccessControlGroups: []string{devGroupID},
+ },
+ routeHA1ID: {
+ ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
+ Peer: peers["peer-80"].Key,
+ PeerID: "peer-80",
+ Description: "HA Route 1", Enabled: true, Metric: 1000,
+ PeerGroups: []string{allGroupID},
+ Groups: []string{allGroupID},
+ AccessControlGroups: []string{allGroupID},
+ },
+ routeHA2ID: {
+ ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
+ Peer: peers["peer-90"].Key,
+ PeerID: "peer-90",
+ Description: "HA Route 2", Enabled: true, Metric: 900,
+ PeerGroups: []string{devGroupID, opsGroupID},
+ Groups: []string{devGroupID, opsGroupID},
+ AccessControlGroups: []string{allGroupID},
+ },
+ }
+
+ account := &types.Account{
+ Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
+ Network: &types.Network{
+ Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
+ },
+ DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
+ NameServerGroups: map[string]*dns.NameServerGroup{
+ nameserverGroupID: {
+ ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
+ NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}},
+ },
+ },
+ PostureChecks: []*posture.Checks{
+ {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
+ NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
+ }},
+ },
+ NetworkResources: []*resourceTypes.NetworkResource{
+ {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
+ },
+ Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
+ NetworkRouters: []*routerTypes.NetworkRouter{
+ {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
+ },
+ Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
+ }
+
+ for _, p := range account.Policies {
+ p.AccountID = account.Id
+ }
+ for _, r := range account.Routes {
+ r.AccountID = account.Id
+ }
+
+ return account
+}
diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go
new file mode 100644
index 000000000..5790f1646
--- /dev/null
+++ b/management/server/types/networkmapbuilder.go
@@ -0,0 +1,2018 @@
+package types
+
+import (
+ "context"
+ "fmt"
+ "slices"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/exp/maps"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+ resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
+ routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/telemetry"
+ "github.com/netbirdio/netbird/route"
+)
+
+const (
+ allPeers = "0.0.0.0"
+ allWildcard = "0.0.0.0/0"
+ v6AllWildcard = "::/0"
+ fw = "fw:"
+ rfw = "route-fw:"
+)
+
+type NetworkMapCache struct {
+ globalRoutes map[route.ID]*route.Route
+ globalRules map[string]*FirewallRule //ruleId
+ globalRouteRules map[string]*RouteFirewallRule //ruleId
+ globalPeers map[string]*nbpeer.Peer
+
+ groupToPeers map[string][]string
+ peerToGroups map[string][]string
+ policyToRules map[string][]*PolicyRule //policyId
+ groupToPolicies map[string][]*Policy
+ groupToRoutes map[string][]*route.Route
+ peerToRoutes map[string][]*route.Route
+
+ peerACLs map[string]*PeerACLView
+ peerRoutes map[string]*PeerRoutesView
+ peerDNS map[string]*nbdns.Config
+
+ resourceRouters map[string]map[string]*routerTypes.NetworkRouter
+ resourcePolicies map[string][]*Policy
+
+ globalResources map[string]*resourceTypes.NetworkResource // resourceId
+
+ acgToRoutes map[string]map[route.ID]*RouteOwnerInfo // routeID -> owner info
+ noACGRoutes map[route.ID]*RouteOwnerInfo
+
+ mu sync.RWMutex
+}
+
+type RouteOwnerInfo struct {
+ PeerID string
+ RouteID route.ID
+}
+
+type PeerACLView struct {
+ ConnectedPeerIDs []string
+ FirewallRuleIDs []string
+}
+
+type PeerRoutesView struct {
+ OwnRouteIDs []route.ID
+ NetworkResourceIDs []route.ID
+ InheritedRouteIDs []route.ID
+ RouteFirewallRuleIDs []string
+}
+
+type NetworkMapBuilder struct {
+ account atomic.Pointer[Account]
+ cache *NetworkMapCache
+ validatedPeers map[string]struct{}
+}
+
+func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder {
+ builder := &NetworkMapBuilder{
+ cache: &NetworkMapCache{
+ globalRoutes: make(map[route.ID]*route.Route),
+ globalRules: make(map[string]*FirewallRule),
+ globalRouteRules: make(map[string]*RouteFirewallRule),
+ globalPeers: make(map[string]*nbpeer.Peer),
+ groupToPeers: make(map[string][]string),
+ peerToGroups: make(map[string][]string),
+ policyToRules: make(map[string][]*PolicyRule),
+ groupToPolicies: make(map[string][]*Policy),
+ groupToRoutes: make(map[string][]*route.Route),
+ peerToRoutes: make(map[string][]*route.Route),
+ peerACLs: make(map[string]*PeerACLView),
+ peerRoutes: make(map[string]*PeerRoutesView),
+ peerDNS: make(map[string]*nbdns.Config),
+ globalResources: make(map[string]*resourceTypes.NetworkResource),
+ acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo),
+ noACGRoutes: make(map[route.ID]*RouteOwnerInfo),
+ },
+ validatedPeers: make(map[string]struct{}),
+ }
+ builder.account.Store(account)
+ maps.Copy(builder.validatedPeers, validatedPeers)
+
+ builder.initialBuild(account)
+
+ return builder
+}
+
+func (b *NetworkMapBuilder) initialBuild(account *Account) {
+ b.cache.mu.Lock()
+ defer b.cache.mu.Unlock()
+
+ start := time.Now()
+
+ b.buildGlobalIndexes(account)
+
+ resourceRouters := account.GetResourceRoutersMap()
+ resourcePolicies := account.GetResourcePoliciesMap()
+ b.cache.resourceRouters = resourceRouters
+ b.cache.resourcePolicies = resourcePolicies
+
+ for peerID := range account.Peers {
+ b.buildPeerACLView(account, peerID)
+ b.buildPeerRoutesView(account, peerID)
+ b.buildPeerDNSView(account, peerID)
+ }
+
+ log.Debugf("NetworkMapBuilder: Initial build completed in %v for account %s", time.Since(start), account.Id)
+}
+
+func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) {
+ clear(b.cache.globalPeers)
+ clear(b.cache.groupToPeers)
+ clear(b.cache.peerToGroups)
+ clear(b.cache.policyToRules)
+ clear(b.cache.groupToPolicies)
+ clear(b.cache.globalRoutes)
+ clear(b.cache.globalRules)
+ clear(b.cache.globalRouteRules)
+ clear(b.cache.globalResources)
+ clear(b.cache.groupToRoutes)
+ clear(b.cache.peerToRoutes)
+ clear(b.cache.acgToRoutes)
+ clear(b.cache.noACGRoutes)
+
+ maps.Copy(b.cache.globalPeers, account.Peers)
+
+ for groupID, group := range account.Groups {
+ peersCopy := make([]string, len(group.Peers))
+ copy(peersCopy, group.Peers)
+ b.cache.groupToPeers[groupID] = peersCopy
+
+ for _, peerID := range group.Peers {
+ b.cache.peerToGroups[peerID] = append(b.cache.peerToGroups[peerID], groupID)
+ }
+ }
+
+ for _, policy := range account.Policies {
+ if !policy.Enabled {
+ continue
+ }
+
+ b.cache.policyToRules[policy.ID] = policy.Rules
+
+ affectedGroups := make(map[string]struct{})
+ for _, rule := range policy.Rules {
+ if !rule.Enabled {
+ continue
+ }
+
+ for _, groupID := range rule.Sources {
+ affectedGroups[groupID] = struct{}{}
+ }
+ for _, groupID := range rule.Destinations {
+ affectedGroups[groupID] = struct{}{}
+ }
+ if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
+ groupId := rule.SourceResource.ID
+ affectedGroups[groupId] = struct{}{}
+ b.cache.peerToGroups[rule.SourceResource.ID] = append(b.cache.peerToGroups[rule.SourceResource.ID], groupId)
+ }
+ if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
+ groupId := rule.DestinationResource.ID
+ affectedGroups[groupId] = struct{}{}
+ b.cache.peerToGroups[rule.DestinationResource.ID] = append(b.cache.peerToGroups[rule.DestinationResource.ID], groupId)
+ }
+ }
+
+ for groupID := range affectedGroups {
+ b.cache.groupToPolicies[groupID] = append(b.cache.groupToPolicies[groupID], policy)
+ }
+ }
+
+ for _, resource := range account.NetworkResources {
+ if !resource.Enabled {
+ continue
+ }
+ b.cache.globalResources[resource.ID] = resource
+ }
+
+ for _, r := range account.Routes {
+ if !r.Enabled {
+ continue
+ }
+ for _, groupID := range r.PeerGroups {
+ b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r)
+ }
+ if r.Peer != "" {
+ if peer, ok := b.cache.globalPeers[r.Peer]; ok {
+ b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r)
+ }
+ }
+ }
+}
+
+func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) {
+ peer := account.GetPeer(peerID)
+ if peer == nil {
+ return
+ }
+
+ allPotentialPeers, firewallRules := b.getPeerConnectionResources(account, peer, b.validatedPeers)
+
+ isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer)
+
+ var emptyExpiredPeers []*nbpeer.Peer
+ finalAllPeers := b.addNetworksRoutingPeers(
+ networkResourcesRoutes,
+ peer,
+ allPotentialPeers,
+ emptyExpiredPeers,
+ isRouter,
+ sourcePeers,
+ )
+
+ view := &PeerACLView{
+ ConnectedPeerIDs: make([]string, 0, len(finalAllPeers)),
+ FirewallRuleIDs: make([]string, 0, len(firewallRules)),
+ }
+
+ for _, p := range finalAllPeers {
+ view.ConnectedPeerIDs = append(view.ConnectedPeerIDs, p.ID)
+ }
+
+ for _, rule := range firewallRules {
+ ruleID := b.generateFirewallRuleID(rule)
+ view.FirewallRuleIDs = append(view.FirewallRuleIDs, ruleID)
+ b.cache.globalRules[ruleID] = rule
+ }
+
+ b.cache.peerACLs[peerID] = view
+}
+
+func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer,
+ validatedPeersMap map[string]struct{},
+) ([]*nbpeer.Peer, []*FirewallRule) {
+ peerID := peer.ID
+
+ peerGroups := b.cache.peerToGroups[peerID]
+ peerGroupsMap := make(map[string]struct{}, len(peerGroups))
+ for _, groupID := range peerGroups {
+ peerGroupsMap[groupID] = struct{}{}
+ }
+
+ rulesExists := make(map[string]struct{})
+ peersExists := make(map[string]struct{})
+ fwRules := make([]*FirewallRule, 0)
+ peers := make([]*nbpeer.Peer, 0)
+
+ for _, group := range peerGroups {
+ policies := b.cache.groupToPolicies[group]
+ for _, policy := range policies {
+ rules := b.cache.policyToRules[policy.ID]
+ for _, rule := range rules {
+ var sourcePeers, destinationPeers []*nbpeer.Peer
+ var peerInSources, peerInDestinations bool
+
+ if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
+ peerInSources = rule.SourceResource.ID == peerID
+ } else {
+ peerInSources = b.isPeerInGroupscached(rule.Sources, peerGroupsMap)
+ }
+
+ if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
+ peerInDestinations = rule.DestinationResource.ID == peerID
+ } else {
+ peerInDestinations = b.isPeerInGroupscached(rule.Destinations, peerGroupsMap)
+ }
+
+ if !peerInSources && !peerInDestinations {
+ continue
+ }
+
+ if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
+ peer := account.GetPeer(rule.SourceResource.ID)
+ if peer != nil {
+ sourcePeers = []*nbpeer.Peer{peer}
+ }
+ } else {
+ sourcePeers = b.getPeersFromGroupscached(account, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
+ }
+
+ if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
+ peer := account.GetPeer(rule.DestinationResource.ID)
+ if peer != nil {
+ destinationPeers = []*nbpeer.Peer{peer}
+ }
+ } else {
+ destinationPeers = b.getPeersFromGroupscached(account, rule.Destinations, peerID, nil, validatedPeersMap)
+ }
+
+ if rule.Bidirectional {
+ if peerInSources {
+ b.generateResourcescached(
+ account, rule, destinationPeers, FirewallRuleDirectionIN,
+ peer, &peers, &fwRules, peersExists, rulesExists,
+ )
+ }
+ if peerInDestinations {
+ b.generateResourcescached(
+ account, rule, sourcePeers, FirewallRuleDirectionOUT,
+ peer, &peers, &fwRules, peersExists, rulesExists,
+ )
+ }
+ }
+
+ if peerInSources {
+ b.generateResourcescached(
+ account, rule, destinationPeers, FirewallRuleDirectionOUT,
+ peer, &peers, &fwRules, peersExists, rulesExists,
+ )
+ }
+
+ if peerInDestinations {
+ b.generateResourcescached(
+ account, rule, sourcePeers, FirewallRuleDirectionIN,
+ peer, &peers, &fwRules, peersExists, rulesExists,
+ )
+ }
+ }
+ }
+ }
+
+ return peers, fwRules
+}
+
+func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool {
+ for _, groupID := range groupIDs {
+ if _, exists := peerGroupsMap[groupID]; exists {
+ return true
+ }
+ }
+ return false
+}
+
+func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs []string,
+ excludePeerID string, postureChecksIDs []string, validatedPeersMap map[string]struct{},
+) []*nbpeer.Peer {
+ ctx := context.Background()
+ uniquePeers := make(map[string]*nbpeer.Peer)
+
+ for _, groupID := range groupIDs {
+ peerIDs := b.cache.groupToPeers[groupID]
+ for _, peerID := range peerIDs {
+ if peerID == excludePeerID {
+ continue
+ }
+
+ if _, ok := validatedPeersMap[peerID]; !ok {
+ continue
+ }
+
+ peer := b.cache.globalPeers[peerID]
+ if peer == nil {
+ continue
+ }
+
+ if len(postureChecksIDs) > 0 {
+ if !account.validatePostureChecksOnPeer(ctx, postureChecksIDs, peerID) {
+ continue
+ }
+ }
+
+ uniquePeers[peerID] = peer
+ }
+ }
+
+ result := make([]*nbpeer.Peer, 0, len(uniquePeers))
+ for _, peer := range uniquePeers {
+ result = append(result, peer)
+ }
+
+ return result
+}
+
+func (b *NetworkMapBuilder) generateResourcescached(
+ account *Account, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer,
+ peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{},
+) {
+ isAll := false
+ if allGroup, err := account.GetGroupAll(); err == nil {
+ isAll = (len(allGroup.Peers) - 1) == len(groupPeers)
+ }
+
+ for _, peer := range groupPeers {
+ if peer == nil {
+ continue
+ }
+ if _, ok := peersExists[peer.ID]; !ok {
+ *peers = append(*peers, peer)
+ peersExists[peer.ID] = struct{}{}
+ }
+
+ fr := FirewallRule{
+ PolicyID: rule.ID,
+ PeerIP: peer.IP.String(),
+ Direction: direction,
+ Action: string(rule.Action),
+ Protocol: string(rule.Protocol),
+ }
+
+ if isAll {
+ fr.PeerIP = allPeers
+ }
+
+ var s strings.Builder
+ s.WriteString(rule.ID)
+ s.WriteString(fr.PeerIP)
+ s.WriteString(strconv.Itoa(direction))
+ s.WriteString(fr.Protocol)
+ s.WriteString(fr.Action)
+ s.WriteString(strings.Join(rule.Ports, ","))
+
+ ruleID := s.String()
+
+ if _, ok := rulesExists[ruleID]; ok {
+ continue
+ }
+ rulesExists[ruleID] = struct{}{}
+
+ if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
+ *rules = append(*rules, &fr)
+ continue
+ }
+
+ *rules = append(*rules, expandPortsAndRanges(fr, rule, targetPeer)...)
+ }
+}
+
+func (b *NetworkMapBuilder) getNetworkResourcesForPeer(account *Account, peer *nbpeer.Peer) (bool, []*route.Route, map[string]struct{}) {
+ ctx := context.Background()
+ peerID := peer.ID
+
+ var isRoutingPeer bool
+ var routes []*route.Route
+ allSourcePeers := make(map[string]struct{})
+
+ peerGroups := b.cache.peerToGroups[peerID]
+ peerGroupsMap := make(map[string]struct{}, len(peerGroups))
+ for _, groupID := range peerGroups {
+ peerGroupsMap[groupID] = struct{}{}
+ }
+
+ for _, resource := range b.cache.globalResources {
+
+ networkRoutingPeers := b.cache.resourceRouters[resource.NetworkID]
+ resourcePolicies := b.cache.resourcePolicies[resource.ID]
+ if len(resourcePolicies) == 0 {
+ continue
+ }
+
+ isRouterForThisResource := false
+
+ if networkRoutingPeers != nil {
+ if router, ok := networkRoutingPeers[peerID]; ok && router.Enabled {
+ isRoutingPeer = true
+ isRouterForThisResource = true
+ if rt := b.createNetworkResourceRoutes(resource, peerID, router, resourcePolicies); rt != nil {
+ routes = append(routes, rt)
+ }
+ }
+ }
+
+ hasAccessAsClient := false
+ if !isRouterForThisResource {
+ for _, policy := range resourcePolicies {
+ if b.isPeerInGroupscached(policy.SourceGroups(), peerGroupsMap) {
+ if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
+ hasAccessAsClient = true
+ break
+ }
+ }
+ }
+ }
+
+ if hasAccessAsClient && networkRoutingPeers != nil {
+ for routerPeerID, router := range networkRoutingPeers {
+ if router.Enabled {
+ if rt := b.createNetworkResourceRoutes(resource, routerPeerID, router, resourcePolicies); rt != nil {
+ routes = append(routes, rt)
+ }
+ }
+ }
+ }
+
+ if isRouterForThisResource {
+ for _, policy := range resourcePolicies {
+ var peersWithAccess []*nbpeer.Peer
+ if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
+ peersWithAccess = []*nbpeer.Peer{peer}
+ } else {
+ peersWithAccess = b.getPeersFromGroupscached(account, policy.SourceGroups(), "", policy.SourcePostureChecks, b.validatedPeers)
+ }
+ for _, p := range peersWithAccess {
+ allSourcePeers[p.ID] = struct{}{}
+ }
+ }
+ }
+ }
+
+ return isRoutingPeer, routes, allSourcePeers
+}
+
+func (b *NetworkMapBuilder) createNetworkResourceRoutes(
+ resource *resourceTypes.NetworkResource, routerPeerID string,
+ router *routerTypes.NetworkRouter, resourcePolicies []*Policy,
+) *route.Route {
+ if len(resourcePolicies) > 0 {
+ peer := b.cache.globalPeers[routerPeerID]
+ if peer != nil {
+ return resource.ToRoute(peer, router)
+ }
+ }
+ return nil
+}
+
+func (b *NetworkMapBuilder) addNetworksRoutingPeers(
+ networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer,
+ expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers map[string]struct{},
+) []*nbpeer.Peer {
+
+ networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
+ for _, r := range networkResourcesRoutes {
+ networkRoutesPeers[r.PeerID] = struct{}{}
+ }
+
+ delete(sourcePeers, peer.ID)
+ delete(networkRoutesPeers, peer.ID)
+
+ for _, existingPeer := range peersToConnect {
+ delete(sourcePeers, existingPeer.ID)
+ delete(networkRoutesPeers, existingPeer.ID)
+ }
+ for _, expPeer := range expiredPeers {
+ delete(sourcePeers, expPeer.ID)
+ delete(networkRoutesPeers, expPeer.ID)
+ }
+
+ missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
+ if isRouter {
+ for p := range sourcePeers {
+ missingPeers[p] = struct{}{}
+ }
+ }
+ for p := range networkRoutesPeers {
+ missingPeers[p] = struct{}{}
+ }
+
+ for p := range missingPeers {
+ if missingPeer := b.cache.globalPeers[p]; missingPeer != nil {
+ peersToConnect = append(peersToConnect, missingPeer)
+ }
+ }
+
+ return peersToConnect
+}
+
+func (b *NetworkMapBuilder) buildPeerRoutesView(account *Account, peerID string) {
+ ctx := context.Background()
+ peer := account.GetPeer(peerID)
+ if peer == nil {
+ return
+ }
+ resourcePolicies := b.cache.resourcePolicies
+
+ view := &PeerRoutesView{
+ OwnRouteIDs: make([]route.ID, 0),
+ NetworkResourceIDs: make([]route.ID, 0),
+ RouteFirewallRuleIDs: make([]string, 0),
+ }
+
+ enabledRoutes, disabledRoutes := b.getRoutingPeerRoutes(peerID)
+ for _, rt := range enabledRoutes {
+ if rt.PeerID != "" && rt.PeerID != peerID {
+ if b.cache.globalPeers[rt.PeerID] == nil {
+ continue
+ }
+ }
+
+ view.OwnRouteIDs = append(view.OwnRouteIDs, rt.ID)
+ b.cache.globalRoutes[rt.ID] = rt
+ }
+
+ aclView := b.cache.peerACLs[peerID]
+ if aclView != nil {
+ peerRoutesMembership := make(LookupMap)
+ for _, r := range append(enabledRoutes, disabledRoutes...) {
+ peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
+ }
+
+ peerGroups := b.cache.peerToGroups[peerID]
+ peerGroupsMap := make(LookupMap)
+ for _, groupID := range peerGroups {
+ peerGroupsMap[groupID] = struct{}{}
+ }
+
+ for _, aclPeerID := range aclView.ConnectedPeerIDs {
+ if aclPeerID == peerID {
+ continue
+ }
+ activeRoutes, _ := b.getRoutingPeerRoutes(aclPeerID)
+ groupFilteredRoutes := account.filterRoutesByGroups(activeRoutes, peerGroupsMap)
+ haFilteredRoutes := account.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
+
+ for _, inheritedRoute := range haFilteredRoutes {
+ view.InheritedRouteIDs = append(view.InheritedRouteIDs, inheritedRoute.ID)
+ b.cache.globalRoutes[inheritedRoute.ID] = inheritedRoute
+ }
+ }
+ }
+
+ _, networkResourcesRoutes, _ := b.getNetworkResourcesForPeer(account, peer)
+
+ for _, rt := range networkResourcesRoutes {
+ view.NetworkResourceIDs = append(view.NetworkResourceIDs, rt.ID)
+ b.cache.globalRoutes[rt.ID] = rt
+ }
+
+ allRoutes := slices.Concat(enabledRoutes, networkResourcesRoutes)
+ b.updateACGIndexForPeer(peerID, allRoutes)
+
+ routeFirewallRules := b.getPeerRoutesFirewallRules(account, peerID, b.validatedPeers)
+ for _, rule := range routeFirewallRules {
+ ruleID := b.generateRouteFirewallRuleID(rule)
+ view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID)
+ b.cache.globalRouteRules[ruleID] = rule
+ }
+
+ if len(networkResourcesRoutes) > 0 {
+ networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies)
+ for _, rule := range networkResourceFirewallRules {
+ ruleID := b.generateRouteFirewallRuleID(rule)
+ view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID)
+ b.cache.globalRouteRules[ruleID] = rule
+ }
+ }
+
+ b.cache.peerRoutes[peerID] = view
+}
+
+func (b *NetworkMapBuilder) updateACGIndexForPeer(peerID string, routes []*route.Route) {
+ for acg, routeMap := range b.cache.acgToRoutes {
+ for routeID, info := range routeMap {
+ if info.PeerID == peerID {
+ delete(routeMap, routeID)
+ }
+ }
+ if len(routeMap) == 0 {
+ delete(b.cache.acgToRoutes, acg)
+ }
+ }
+
+ for routeID, info := range b.cache.noACGRoutes {
+ if info.PeerID == peerID {
+ delete(b.cache.noACGRoutes, routeID)
+ }
+ }
+
+ for _, rt := range routes {
+ if !rt.Enabled {
+ continue
+ }
+
+ if len(rt.AccessControlGroups) == 0 {
+ b.cache.noACGRoutes[rt.ID] = &RouteOwnerInfo{
+ PeerID: peerID,
+ RouteID: rt.ID,
+ }
+ } else {
+ for _, acg := range rt.AccessControlGroups {
+ if b.cache.acgToRoutes[acg] == nil {
+ b.cache.acgToRoutes[acg] = make(map[route.ID]*RouteOwnerInfo)
+ }
+
+ b.cache.acgToRoutes[acg][rt.ID] = &RouteOwnerInfo{
+ PeerID: peerID,
+ RouteID: rt.ID,
+ }
+ }
+ }
+ }
+}
+
+func (b *NetworkMapBuilder) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
+ peer := b.cache.globalPeers[peerID]
+ if peer == nil {
+ return enabledRoutes, disabledRoutes
+ }
+
+ seenRoute := make(map[route.ID]struct{})
+
+ takeRoute := func(r *route.Route, id string) {
+ if _, ok := seenRoute[r.ID]; ok {
+ return
+ }
+ seenRoute[r.ID] = struct{}{}
+
+ if r.Enabled {
+ // maybe here is some mess - here we store peer key (see comment below)
+ r.Peer = peer.Key
+ enabledRoutes = append(enabledRoutes, r)
+ return
+ }
+ disabledRoutes = append(disabledRoutes, r)
+ }
+
+ peerGroups := b.cache.peerToGroups[peerID]
+ for _, groupID := range peerGroups {
+ groupRoutes := b.cache.groupToRoutes[groupID]
+ for _, r := range groupRoutes {
+ newPeerRoute := r.Copy()
+ // and here we store peer ID - this logic is taken from original account.getRoutingPeerRoutes
+ newPeerRoute.Peer = peerID
+ newPeerRoute.PeerGroups = nil
+ newPeerRoute.ID = route.ID(string(r.ID) + ":" + peerID)
+ takeRoute(newPeerRoute, peerID)
+ }
+ }
+ for _, r := range b.cache.peerToRoutes[peerID] {
+ takeRoute(r.Copy(), peerID)
+ }
+ return enabledRoutes, disabledRoutes
+}
+
+func (b *NetworkMapBuilder) getPeerRoutesFirewallRules(account *Account, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
+ routesFirewallRules := make([]*RouteFirewallRule, 0)
+
+ enabledRoutes, _ := b.getRoutingPeerRoutes(peerID)
+ for _, route := range enabledRoutes {
+ if len(route.AccessControlGroups) == 0 {
+ defaultPermit := getDefaultPermit(route)
+ routesFirewallRules = append(routesFirewallRules, defaultPermit...)
+ continue
+ }
+
+ distributionPeers := b.getDistributionGroupsPeers(route)
+
+ for _, accessGroup := range route.AccessControlGroups {
+ policies := b.getAllRoutePoliciesFromGroups([]string{accessGroup})
+
+ rules := b.getRouteFirewallRules(peerID, policies, route, validatedPeersMap, distributionPeers, account)
+ routesFirewallRules = append(routesFirewallRules, rules...)
+ }
+ }
+
+ return routesFirewallRules
+}
+
+func (b *NetworkMapBuilder) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
+ distPeers := make(map[string]struct{})
+ for _, id := range route.Groups {
+ groupPeers := b.cache.groupToPeers[id]
+ if groupPeers == nil {
+ continue
+ }
+
+ for _, pID := range groupPeers {
+ distPeers[pID] = struct{}{}
+ }
+ }
+ return distPeers
+}
+
+func (b *NetworkMapBuilder) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy {
+ routePolicies := make(map[string]*Policy)
+
+ for _, groupID := range accessControlGroups {
+ candidatePolicies := b.cache.groupToPolicies[groupID]
+
+ for _, policy := range candidatePolicies {
+ if _, found := routePolicies[policy.ID]; found {
+ continue
+ }
+ policyRules := b.cache.policyToRules[policy.ID]
+ for _, rule := range policyRules {
+ if slices.Contains(rule.Destinations, groupID) {
+ routePolicies[policy.ID] = policy
+ break
+ }
+ }
+ }
+ }
+
+ return maps.Values(routePolicies)
+}
+
+func (b *NetworkMapBuilder) getRouteFirewallRules(
+ peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{},
+ distributionPeers map[string]struct{}, account *Account,
+) []*RouteFirewallRule {
+ ctx := context.Background()
+ var fwRules []*RouteFirewallRule
+ for _, policy := range policies {
+ if !policy.Enabled {
+ continue
+ }
+
+ for _, rule := range policy.Rules {
+ if !rule.Enabled {
+ continue
+ }
+
+ rulePeers := b.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap, account)
+
+ rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
+ fwRules = append(fwRules, rules...)
+ }
+ }
+ return fwRules
+}
+
+func (b *NetworkMapBuilder) getRulePeers(
+ rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{},
+ validatedPeersMap map[string]struct{}, account *Account,
+) []*nbpeer.Peer {
+ distPeersWithPolicy := make(map[string]struct{})
+
+ for _, id := range rule.Sources {
+ groupPeers := b.cache.groupToPeers[id]
+ if groupPeers == nil {
+ continue
+ }
+
+ for _, pID := range groupPeers {
+ if pID == peerID {
+ continue
+ }
+ _, distPeer := distributionPeers[pID]
+ _, valid := validatedPeersMap[pID]
+
+ if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
+ distPeersWithPolicy[pID] = struct{}{}
+ }
+ }
+ }
+
+ if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
+ _, distPeer := distributionPeers[rule.SourceResource.ID]
+ _, valid := validatedPeersMap[rule.SourceResource.ID]
+ if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) {
+ distPeersWithPolicy[rule.SourceResource.ID] = struct{}{}
+ }
+ }
+
+ distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
+ for pID := range distPeersWithPolicy {
+ peer := b.cache.globalPeers[pID]
+ if peer == nil {
+ continue
+ }
+ distributionGroupPeers = append(distributionGroupPeers, peer)
+ }
+ return distributionGroupPeers
+}
+
+func (b *NetworkMapBuilder) buildPeerDNSView(account *Account, peerID string) {
+ peerGroups := b.cache.peerToGroups[peerID]
+ checkGroups := make(map[string]struct{}, len(peerGroups))
+ for _, groupID := range peerGroups {
+ checkGroups[groupID] = struct{}{}
+ }
+
+ dnsManagementStatus := b.getPeerDNSManagementStatus(account, checkGroups)
+ dnsConfig := &nbdns.Config{
+ ServiceEnable: dnsManagementStatus,
+ }
+
+ if dnsManagementStatus {
+ dnsConfig.NameServerGroups = b.getPeerNSGroups(account, peerID, checkGroups)
+ }
+
+ b.cache.peerDNS[peerID] = dnsConfig
+}
+
+func (b *NetworkMapBuilder) getPeerDNSManagementStatus(account *Account, checkGroups map[string]struct{}) bool {
+
+ enabled := true
+ for _, groupID := range account.DNSSettings.DisabledManagementGroups {
+ _, found := checkGroups[groupID]
+ if found {
+ enabled = false
+ break
+ }
+ }
+ return enabled
+}
+
+func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, checkGroups map[string]struct{}) []*nbdns.NameServerGroup {
+ var peerNSGroups []*nbdns.NameServerGroup
+
+ for _, nsGroup := range account.NameServerGroups {
+ if !nsGroup.Enabled {
+ continue
+ }
+ for _, gID := range nsGroup.Groups {
+ _, found := checkGroups[gID]
+ if found {
+ peer := b.cache.globalPeers[peerID]
+ if !peerIsNameserver(peer, nsGroup) {
+ peerNSGroups = append(peerNSGroups, nsGroup.Copy())
+ break
+ }
+ }
+ }
+ }
+
+ return peerNSGroups
+}
+
+func (b *NetworkMapBuilder) UpdateAccountPointer(account *Account) {
+ b.account.Store(account)
+}
+
+func (b *NetworkMapBuilder) GetPeerNetworkMap(
+ ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone,
+ validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
+) *NetworkMap {
+ start := time.Now()
+ account := b.account.Load()
+
+ peer := account.GetPeer(peerID)
+ if peer == nil {
+ return &NetworkMap{Network: account.Network.Copy()}
+ }
+
+ b.cache.mu.RLock()
+ defer b.cache.mu.RUnlock()
+
+ aclView := b.cache.peerACLs[peerID]
+ routesView := b.cache.peerRoutes[peerID]
+ dnsConfig := b.cache.peerDNS[peerID]
+
+ if aclView == nil || routesView == nil || dnsConfig == nil {
+ return &NetworkMap{Network: account.Network.Copy()}
+ }
+
+ nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers)
+
+ if metrics != nil {
+ objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
+ metrics.CountNetworkMapObjects(objectCount)
+ metrics.CountGetPeerNetworkMapDuration(time.Since(start))
+
+ if objectCount > 5000 {
+ log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache",
+ account.Id, objectCount)
+ }
+ }
+
+ return nm
+}
+
+func (b *NetworkMapBuilder) assembleNetworkMap(
+ account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
+ dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
+) *NetworkMap {
+
+ var peersToConnect []*nbpeer.Peer
+ var expiredPeers []*nbpeer.Peer
+
+ for _, peerID := range aclView.ConnectedPeerIDs {
+ if _, ok := validatedPeers[peerID]; !ok {
+ continue
+ }
+
+ peer := b.cache.globalPeers[peerID]
+ if peer == nil {
+ continue
+ }
+
+ expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration)
+ if account.Settings.PeerLoginExpirationEnabled && expired {
+ expiredPeers = append(expiredPeers, peer)
+ } else {
+ peersToConnect = append(peersToConnect, peer)
+ }
+ }
+
+ var routes []*route.Route
+ allRouteIDs := slices.Concat(routesView.OwnRouteIDs, routesView.NetworkResourceIDs, routesView.InheritedRouteIDs)
+
+ for _, routeID := range allRouteIDs {
+ if route := b.cache.globalRoutes[routeID]; route != nil {
+ routes = append(routes, route)
+ }
+ }
+
+ var firewallRules []*FirewallRule
+ for _, ruleID := range aclView.FirewallRuleIDs {
+ if rule := b.cache.globalRules[ruleID]; rule != nil {
+ firewallRules = append(firewallRules, rule)
+ }
+ }
+
+ var routesFirewallRules []*RouteFirewallRule
+ for _, ruleID := range routesView.RouteFirewallRuleIDs {
+ if rule := b.cache.globalRouteRules[ruleID]; rule != nil {
+ routesFirewallRules = append(routesFirewallRules, rule)
+ }
+ }
+
+ finalDNSConfig := *dnsConfig
+ if finalDNSConfig.ServiceEnable && customZone.Domain != "" {
+ var zones []nbdns.CustomZone
+ records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers)
+ zones = append(zones, nbdns.CustomZone{
+ Domain: customZone.Domain,
+ Records: records,
+ })
+ finalDNSConfig.CustomZones = zones
+ }
+
+ return &NetworkMap{
+ Peers: peersToConnect,
+ Network: account.Network.Copy(),
+ Routes: routes,
+ DNSConfig: finalDNSConfig,
+ OfflinePeers: expiredPeers,
+ FirewallRules: firewallRules,
+ RoutesFirewallRules: routesFirewallRules,
+ }
+}
+
+func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string {
+ var s strings.Builder
+ s.WriteString(fw)
+ s.WriteString(rule.PolicyID)
+ s.WriteRune(':')
+ s.WriteString(rule.PeerIP)
+ s.WriteRune(':')
+ s.WriteString(strconv.Itoa(rule.Direction))
+ s.WriteRune(':')
+ s.WriteString(rule.Protocol)
+ s.WriteRune(':')
+ s.WriteString(rule.Action)
+ s.WriteRune(':')
+ s.WriteString(rule.Port)
+ s.WriteRune(':')
+ s.WriteString(strconv.Itoa(int(rule.PortRange.Start)))
+ s.WriteRune('-')
+ s.WriteString(strconv.Itoa(int(rule.PortRange.End)))
+ return s.String()
+}
+
+func (b *NetworkMapBuilder) generateRouteFirewallRuleID(rule *RouteFirewallRule) string {
+ var s strings.Builder
+ s.WriteString(rfw)
+ s.WriteString(string(rule.RouteID))
+ s.WriteRune(':')
+ s.WriteString(rule.Destination)
+ s.WriteRune(':')
+ s.WriteString(rule.Action)
+ s.WriteRune(':')
+ s.WriteString(strings.Join(rule.SourceRanges, ","))
+ s.WriteRune(':')
+ s.WriteString(rule.Protocol)
+ s.WriteRune(':')
+ s.WriteString(strconv.Itoa(int(rule.Port)))
+ return s.String()
+}
+
+func (b *NetworkMapBuilder) isPeerInGroups(groupIDs []string, peerGroups []string) bool {
+ for _, groupID := range groupIDs {
+ if slices.Contains(peerGroups, groupID) {
+ return true
+ }
+ }
+ return false
+}
+
+func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool {
+ for _, r := range account.Routes {
+ if !r.Enabled {
+ continue
+ }
+
+ if r.PeerID == peerID {
+ return true
+ }
+
+ if peer := b.cache.globalPeers[peerID]; peer != nil {
+ if r.Peer == peer.Key && r.PeerID == "" {
+ return true
+ }
+ }
+ }
+
+ routers := account.GetResourceRoutersMap()
+ for _, networkRouters := range routers {
+ if router, exists := networkRouters[peerID]; exists && router.Enabled {
+ return true
+ }
+ }
+
+ return false
+}
+
+type ViewDelta struct {
+ AddedPeerIDs []string
+ RemovedPeerIDs []string
+ AddedRuleIDs []string
+ RemovedRuleIDs []string
+}
+
+func (b *NetworkMapBuilder) OnPeerAddedIncremental(peerID string) error {
+ tt := time.Now()
+ account := b.account.Load()
+ peer := account.GetPeer(peerID)
+ if peer == nil {
+ return fmt.Errorf("peer %s not found in account", peerID)
+ }
+
+ b.cache.mu.Lock()
+ defer b.cache.mu.Unlock()
+
+ log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String())
+
+ b.validatedPeers[peerID] = struct{}{}
+
+ b.cache.globalPeers[peerID] = peer
+
+ peerGroups := b.updateIndexesForNewPeer(account, peerID)
+
+ b.buildPeerACLView(account, peerID)
+ b.buildPeerRoutesView(account, peerID)
+ b.buildPeerDNSView(account, peerID)
+
+ log.Debugf("NetworkMapBuilder: Adding peer %s to cache, views took %s", peerID, time.Since(tt))
+
+ b.incrementalUpdateAffectedPeers(account, peerID, peerGroups)
+
+ log.Debugf("NetworkMapBuilder: Added peer %s to cache, took %s", peerID, time.Since(tt))
+
+ return nil
+}
+
+func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID string) []string {
+ peerGroups := make([]string, 0)
+
+ for groupID, group := range account.Groups {
+ if slices.Contains(group.Peers, peerID) {
+ if !slices.Contains(b.cache.groupToPeers[groupID], peerID) {
+ b.cache.groupToPeers[groupID] = append(b.cache.groupToPeers[groupID], peerID)
+ }
+ peerGroups = append(peerGroups, groupID)
+ }
+ }
+
+ b.cache.peerToGroups[peerID] = peerGroups
+
+ for _, r := range account.Routes {
+ if !r.Enabled || b.cache.globalRoutes[r.ID] != nil {
+ continue
+ }
+ for _, groupID := range r.PeerGroups {
+ if !slices.Contains(b.cache.groupToRoutes[groupID], r) {
+ b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r)
+ }
+ }
+ if r.Peer != "" {
+ if peer, ok := b.cache.globalPeers[r.Peer]; ok {
+ if !slices.Contains(b.cache.peerToRoutes[peer.ID], r) {
+ b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r)
+ }
+ }
+ }
+ b.cache.globalRoutes[r.ID] = r
+ }
+
+ return peerGroups
+}
+
+func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) {
+ updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups)
+
+ if b.isPeerRouter(account, newPeerID) {
+ affectedByRoutes := b.findPeersAffectedByNewRouter(account, newPeerID, peerGroups)
+ for affectedPeerID := range affectedByRoutes {
+ if affectedPeerID == newPeerID {
+ continue
+ }
+ if _, exists := updates[affectedPeerID]; !exists {
+ updates[affectedPeerID] = &PeerUpdateDelta{
+ PeerID: affectedPeerID,
+ RebuildRoutesView: true,
+ }
+ } else {
+ updates[affectedPeerID].RebuildRoutesView = true
+ }
+ }
+ }
+
+ for affectedPeerID, delta := range updates {
+ b.applyDeltaToPeer(account, affectedPeerID, delta)
+ }
+}
+
+func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} {
+ affected := make(map[string]struct{})
+ enabledRoutes, _ := b.getRoutingPeerRoutes(newRouterID)
+
+ for _, route := range enabledRoutes {
+ for _, distGroupID := range route.Groups {
+ if peers := b.cache.groupToPeers[distGroupID]; peers != nil {
+ for _, peerID := range peers {
+ if peerID != newRouterID {
+ affected[peerID] = struct{}{}
+ }
+ }
+ }
+ }
+
+ for _, peerGroupID := range route.PeerGroups {
+ if peers := b.cache.groupToPeers[peerGroupID]; peers != nil {
+ for _, peerID := range peers {
+ if peerID != newRouterID {
+ affected[peerID] = struct{}{}
+ }
+ }
+ }
+ }
+ }
+
+ for _, route := range account.Routes {
+ if !route.Enabled {
+ continue
+ }
+
+ routerInPeerGroups := false
+ for _, peerGroupID := range route.PeerGroups {
+ if slices.Contains(routerGroups, peerGroupID) {
+ routerInPeerGroups = true
+ break
+ }
+ }
+
+ if routerInPeerGroups {
+ for _, distGroupID := range route.Groups {
+ if peers := b.cache.groupToPeers[distGroupID]; peers != nil {
+ for _, peerID := range peers {
+ affected[peerID] = struct{}{}
+ }
+ }
+ }
+ }
+ }
+
+ return affected
+}
+
+func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta {
+ updates := make(map[string]*PeerUpdateDelta)
+ ctx := context.Background()
+
+ groupAllLn := 0
+ if allGroup, err := account.GetGroupAll(); err == nil {
+ groupAllLn = len(allGroup.Peers) - 1
+ }
+
+ newPeer := b.cache.globalPeers[newPeerID]
+ if newPeer == nil {
+ return updates
+ }
+
+ for _, policy := range account.Policies {
+ if !policy.Enabled {
+ continue
+ }
+
+ for _, rule := range policy.Rules {
+ if !rule.Enabled {
+ continue
+ }
+ var peerInSources, peerInDestinations bool
+
+ if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID == newPeerID {
+ peerInSources = true
+ } else {
+ peerInSources = b.isPeerInGroups(rule.Sources, peerGroups)
+ }
+
+ if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID == newPeerID {
+ peerInDestinations = true
+ } else {
+ peerInDestinations = b.isPeerInGroups(rule.Destinations, peerGroups)
+ }
+
+ if peerInSources {
+ if len(rule.Destinations) > 0 {
+ b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn)
+ }
+ if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
+ b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionIN)
+ }
+ }
+
+ if peerInDestinations {
+ if len(rule.Sources) > 0 {
+ b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn)
+ }
+ if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
+ b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionOUT)
+ }
+ }
+
+ if rule.Bidirectional {
+ if peerInSources {
+ if len(rule.Destinations) > 0 {
+ b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn)
+ }
+ if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
+ b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionOUT)
+ }
+ }
+ if peerInDestinations {
+ if len(rule.Sources) > 0 {
+ b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn)
+ }
+ if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
+ b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionIN)
+ }
+ }
+ }
+ }
+ }
+
+ b.calculateRouteFirewallUpdates(newPeerID, newPeer, peerGroups, updates)
+
+ b.calculateNetworkResourceFirewallUpdates(ctx, account, newPeerID, newPeer, peerGroups, updates)
+
+ b.calculateNewRouterNetworkResourceUpdates(ctx, account, newPeerID, updates)
+
+ return updates
+}
+
+func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates(
+ ctx context.Context, account *Account, newPeerID string,
+ updates map[string]*PeerUpdateDelta,
+) {
+ resourceRouters := b.cache.resourceRouters
+
+ for networkID, routers := range resourceRouters {
+ router, isRouter := routers[newPeerID]
+ if !isRouter || !router.Enabled {
+ continue
+ }
+
+ for _, resource := range b.cache.globalResources {
+ if resource.NetworkID != networkID {
+ continue
+ }
+
+ policies := b.cache.resourcePolicies[resource.ID]
+ if len(policies) == 0 {
+ continue
+ }
+
+ peersWithAccess := make(map[string]struct{})
+
+ for _, policy := range policies {
+ if !policy.Enabled {
+ continue
+ }
+
+ sourceGroups := policy.SourceGroups()
+ for _, sourceGroup := range sourceGroups {
+ groupPeers := b.cache.groupToPeers[sourceGroup]
+ for _, peerID := range groupPeers {
+ if peerID == newPeerID {
+ continue
+ }
+
+ if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
+ peersWithAccess[peerID] = struct{}{}
+ }
+ }
+ }
+ }
+
+ for peerID := range peersWithAccess {
+ delta := updates[peerID]
+ if delta == nil {
+ delta = &PeerUpdateDelta{
+ PeerID: peerID,
+ }
+ updates[peerID] = delta
+ }
+
+ if delta.AddConnectedPeer == "" {
+ delta.AddConnectedPeer = newPeerID
+ }
+
+ delta.RebuildRoutesView = true
+ }
+ }
+ }
+}
+
+func (b *NetworkMapBuilder) calculateRouteFirewallUpdates(
+ newPeerID string, newPeer *nbpeer.Peer,
+ peerGroups []string, updates map[string]*PeerUpdateDelta,
+) {
+ processedPeerRoutes := make(map[string]map[route.ID]struct{})
+
+ for routeID, info := range b.cache.noACGRoutes {
+ if info.PeerID == newPeerID {
+ continue
+ }
+
+ b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String())
+
+ if processedPeerRoutes[info.PeerID] == nil {
+ processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{})
+ }
+ processedPeerRoutes[info.PeerID][routeID] = struct{}{}
+ }
+
+ for _, acg := range peerGroups {
+ routeInfos := b.cache.acgToRoutes[acg]
+ if routeInfos == nil {
+ continue
+ }
+
+ for routeID, info := range routeInfos {
+ if info.PeerID == newPeerID {
+ continue
+ }
+
+ if processedRoutes, exists := processedPeerRoutes[info.PeerID]; exists {
+ if _, processed := processedRoutes[routeID]; processed {
+ continue
+ }
+ }
+
+ b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String())
+
+ if processedPeerRoutes[info.PeerID] == nil {
+ processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{})
+ }
+ processedPeerRoutes[info.PeerID][routeID] = struct{}{}
+ }
+ }
+}
+
+func (b *NetworkMapBuilder) addRouteFirewallUpdate(
+ updates map[string]*PeerUpdateDelta, peerID string,
+ routeID string, sourceIP string,
+) {
+ delta := updates[peerID]
+ if delta == nil {
+ delta = &PeerUpdateDelta{
+ PeerID: peerID,
+ UpdateRouteFirewallRules: make([]*RouteFirewallRuleUpdate, 0),
+ }
+ updates[peerID] = delta
+ }
+
+ for _, existing := range delta.UpdateRouteFirewallRules {
+ if existing.RuleID == routeID && existing.AddSourceIP == sourceIP {
+ return
+ }
+ }
+
+ delta.UpdateRouteFirewallRules = append(delta.UpdateRouteFirewallRules, &RouteFirewallRuleUpdate{
+ RuleID: routeID,
+ AddSourceIP: sourceIP,
+ })
+}
+
+func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates(
+ ctx context.Context, account *Account, newPeerID string,
+ newPeer *nbpeer.Peer, peerGroups []string, updates map[string]*PeerUpdateDelta,
+) {
+ for _, resource := range b.cache.globalResources {
+ resourcePolicies := b.cache.resourcePolicies
+ resourceRouters := b.cache.resourceRouters
+
+ policies := resourcePolicies[resource.ID]
+ peerHasAccess := false
+
+ for _, policy := range policies {
+ if !policy.Enabled {
+ continue
+ }
+
+ sourceGroups := policy.SourceGroups()
+ for _, sourceGroup := range sourceGroups {
+ if slices.Contains(peerGroups, sourceGroup) {
+ if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, newPeerID) {
+ peerHasAccess = true
+ break
+ }
+ }
+ }
+
+ if peerHasAccess {
+ break
+ }
+ }
+
+ if !peerHasAccess {
+ continue
+ }
+
+ networkRouters := resourceRouters[resource.NetworkID]
+ for routerPeerID, router := range networkRouters {
+ if !router.Enabled || routerPeerID == newPeerID {
+ continue
+ }
+
+ delta := updates[routerPeerID]
+ if delta == nil {
+ delta = &PeerUpdateDelta{
+ PeerID: routerPeerID,
+ }
+ updates[routerPeerID] = delta
+ }
+
+ if delta.AddConnectedPeer == "" {
+ delta.AddConnectedPeer = newPeerID
+ }
+
+ delta.RebuildRoutesView = true
+ }
+ }
+}
+
+type PeerUpdateDelta struct {
+ PeerID string
+ AddConnectedPeer string
+ AddFirewallRules []*FirewallRuleDelta
+ AddRoutes []route.ID
+ UpdateRouteFirewallRules []*RouteFirewallRuleUpdate
+ UpdateDNS bool
+ RebuildRoutesView bool
+}
+type FirewallRuleDelta struct {
+ Rule *FirewallRule
+ RuleID string
+ Direction int
+}
+
+type RouteFirewallRuleUpdate struct {
+ RuleID string
+ AddSourceIP string
+}
+
+func (b *NetworkMapBuilder) addUpdateForPeersInGroups(
+ updates map[string]*PeerUpdateDelta, groupIDs []string, newPeerID string,
+ rule *PolicyRule, direction int, allGroupLn int,
+) {
+ for _, groupID := range groupIDs {
+ peers := b.cache.groupToPeers[groupID]
+ cnt := 0
+ for _, peerID := range peers {
+ if peerID == newPeerID {
+ continue
+ }
+ if _, ok := b.validatedPeers[peerID]; !ok {
+ continue
+ }
+ cnt++
+ }
+ all := false
+ if allGroupLn > 0 && cnt == allGroupLn {
+ all = true
+ }
+ newPeer := b.cache.globalPeers[newPeerID]
+ fr := &FirewallRule{
+ PolicyID: rule.ID,
+ PeerIP: newPeer.IP.String(),
+ Direction: direction,
+ Action: string(rule.Action),
+ Protocol: string(rule.Protocol),
+ }
+ for _, peerID := range peers {
+ if peerID == newPeerID {
+ continue
+ }
+ if _, ok := b.validatedPeers[peerID]; !ok {
+ continue
+ }
+ targetPeer := b.cache.globalPeers[peerID]
+ if targetPeer == nil {
+ continue
+ }
+
+ peerIPForRule := fr.PeerIP
+ if all {
+ peerIPForRule = allPeers
+ }
+
+ b.addOrUpdateFirewallRuleInDelta(updates, peerID, newPeerID, rule, direction, fr, peerIPForRule, targetPeer)
+ }
+ }
+}
+
+func (b *NetworkMapBuilder) addUpdateForDirectPeerResource(
+ updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string,
+ rule *PolicyRule, direction int,
+) {
+ if targetPeerID == newPeerID {
+ return
+ }
+
+ if _, ok := b.validatedPeers[targetPeerID]; !ok {
+ return
+ }
+
+ newPeer := b.cache.globalPeers[newPeerID]
+ if newPeer == nil {
+ return
+ }
+
+ targetPeer := b.cache.globalPeers[targetPeerID]
+ if targetPeer == nil {
+ return
+ }
+
+ fr := &FirewallRule{
+ PolicyID: rule.ID,
+ PeerIP: newPeer.IP.String(),
+ Direction: direction,
+ Action: string(rule.Action),
+ Protocol: string(rule.Protocol),
+ }
+
+ b.addOrUpdateFirewallRuleInDelta(updates, targetPeerID, newPeerID, rule, direction, fr, fr.PeerIP, targetPeer)
+}
+
+func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta(
+ updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string,
+ rule *PolicyRule, direction int, baseRule *FirewallRule, peerIP string, targetPeer *nbpeer.Peer,
+) {
+ delta := updates[targetPeerID]
+ if delta == nil {
+ delta = &PeerUpdateDelta{
+ PeerID: targetPeerID,
+ AddConnectedPeer: newPeerID,
+ AddFirewallRules: make([]*FirewallRuleDelta, 0),
+ }
+ updates[targetPeerID] = delta
+ }
+
+ baseRule.PeerIP = peerIP
+
+ if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 {
+ expandedRules := expandPortsAndRanges(*baseRule, rule, targetPeer)
+ for _, expandedRule := range expandedRules {
+ ruleID := b.generateFirewallRuleID(expandedRule)
+ delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{
+ Rule: expandedRule,
+ RuleID: ruleID,
+ Direction: direction,
+ })
+ }
+ } else {
+ ruleID := b.generateFirewallRuleID(baseRule)
+ delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{
+ Rule: baseRule,
+ RuleID: ruleID,
+ Direction: direction,
+ })
+ }
+}
+
+func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) {
+ if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 {
+ if aclView := b.cache.peerACLs[peerID]; aclView != nil {
+ if delta.AddConnectedPeer != "" && !slices.Contains(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) {
+ aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, delta.AddConnectedPeer)
+ }
+
+ for _, ruleDelta := range delta.AddFirewallRules {
+ b.cache.globalRules[ruleDelta.RuleID] = ruleDelta.Rule
+
+ if !slices.Contains(aclView.FirewallRuleIDs, ruleDelta.RuleID) {
+ aclView.FirewallRuleIDs = append(aclView.FirewallRuleIDs, ruleDelta.RuleID)
+ }
+ }
+ }
+ }
+
+ if delta.RebuildRoutesView {
+ b.buildPeerRoutesView(account, peerID)
+ } else if len(delta.UpdateRouteFirewallRules) > 0 {
+ if routesView := b.cache.peerRoutes[peerID]; routesView != nil {
+ b.updateRouteFirewallRules(routesView, delta.UpdateRouteFirewallRules)
+ }
+ }
+
+ if delta.UpdateDNS {
+ b.buildPeerDNSView(account, peerID)
+ }
+}
+
+func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, updates []*RouteFirewallRuleUpdate) {
+ for _, update := range updates {
+ for _, ruleID := range routesView.RouteFirewallRuleIDs {
+ rule := b.cache.globalRouteRules[ruleID]
+ if rule == nil {
+ continue
+ }
+
+ if string(rule.RouteID) == update.RuleID {
+ if hasWildcard := slices.Contains(rule.SourceRanges, allWildcard) || slices.Contains(rule.SourceRanges, v6AllWildcard); hasWildcard {
+ break
+ }
+
+ sourceIP := update.AddSourceIP
+
+ if strings.Contains(sourceIP, ":") {
+ sourceIP += "/128" // IPv6
+ } else {
+ sourceIP += "/32" // IPv4
+ }
+
+ if !slices.Contains(rule.SourceRanges, sourceIP) {
+ rule.SourceRanges = append(rule.SourceRanges, sourceIP)
+ }
+ break
+ }
+ }
+ }
+}
+
+func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error {
+ b.cache.mu.Lock()
+ defer b.cache.mu.Unlock()
+
+ account := b.account.Load()
+
+ deletedPeer := b.cache.globalPeers[peerID]
+ if deletedPeer == nil {
+ return fmt.Errorf("peer %s not found in cache", peerID)
+ }
+
+ deletedPeerKey := deletedPeer.Key
+ peerGroups := b.cache.peerToGroups[peerID]
+ peerIP := deletedPeer.IP.String()
+
+ log.Debugf("NetworkMapBuilder: Deleting peer %s (IP: %s) from cache", peerID, peerIP)
+
+ delete(b.validatedPeers, peerID)
+
+ routesToDelete := []route.ID{}
+
+ for routeID, r := range account.Routes {
+ if r.Peer != deletedPeerKey && r.PeerID != peerID {
+ continue
+ }
+ if len(r.PeerGroups) == 0 {
+ routesToDelete = append(routesToDelete, routeID)
+ continue
+ }
+ newPeerAssigned := false
+ for _, groupID := range r.PeerGroups {
+ candidatePeerIDs := b.cache.groupToPeers[groupID]
+ for _, candidatePeerID := range candidatePeerIDs {
+ if candidatePeerID == peerID {
+ continue
+ }
+ if candidatePeer := b.cache.globalPeers[candidatePeerID]; candidatePeer != nil {
+ r.Peer = candidatePeer.Key
+ r.PeerID = candidatePeerID
+ newPeerAssigned = true
+ break
+ }
+ }
+ if newPeerAssigned {
+ break
+ }
+ }
+
+ if !newPeerAssigned {
+ routesToDelete = append(routesToDelete, routeID)
+ }
+ }
+
+ for _, routeID := range routesToDelete {
+ delete(account.Routes, routeID)
+ }
+
+ delete(b.cache.peerACLs, peerID)
+ delete(b.cache.peerRoutes, peerID)
+ delete(b.cache.peerDNS, peerID)
+
+ delete(b.cache.globalPeers, peerID)
+
+ for acg, routeMap := range b.cache.acgToRoutes {
+ for routeID, info := range routeMap {
+ if info.PeerID == peerID {
+ delete(routeMap, routeID)
+ }
+ }
+ if len(routeMap) == 0 {
+ delete(b.cache.acgToRoutes, acg)
+ }
+ }
+
+ for _, groupID := range peerGroups {
+ if peers := b.cache.groupToPeers[groupID]; peers != nil {
+ b.cache.groupToPeers[groupID] = slices.DeleteFunc(peers, func(id string) bool {
+ return id == peerID
+ })
+ }
+ }
+ delete(b.cache.peerToGroups, peerID)
+
+ affectedPeers := make(map[string]struct{})
+
+ for _, r := range account.Routes {
+ for _, groupID := range r.Groups {
+ if peers := b.cache.groupToPeers[groupID]; peers != nil {
+ for _, p := range peers {
+ affectedPeers[p] = struct{}{}
+ }
+ }
+ }
+
+ for _, groupID := range r.PeerGroups {
+ if peers := b.cache.groupToPeers[groupID]; peers != nil {
+ for _, p := range peers {
+ affectedPeers[p] = struct{}{}
+ }
+ }
+ }
+ }
+
+ for affectedPeerID := range affectedPeers {
+ if affectedPeerID == peerID {
+ continue
+ }
+ b.buildPeerRoutesView(account, affectedPeerID)
+ }
+
+ peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP)
+ for affectedPeerID, updates := range peerDeletionUpdates {
+ b.applyDeletionUpdates(affectedPeerID, updates)
+ }
+
+ b.cleanupUnusedRules()
+
+ log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers))
+
+ return nil
+}
+
+func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL(
+ deletedPeerID string,
+ peerIP string,
+) map[string]*PeerDeletionUpdate {
+
+ affected := make(map[string]*PeerDeletionUpdate)
+
+ for peerID, aclView := range b.cache.peerACLs {
+ if peerID == deletedPeerID {
+ continue
+ }
+
+ if !slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) {
+ continue
+ }
+ if affected[peerID] == nil {
+ affected[peerID] = &PeerDeletionUpdate{
+ RemovePeerID: deletedPeerID,
+ PeerIP: peerIP,
+ }
+ }
+
+ for _, ruleID := range aclView.FirewallRuleIDs {
+ if rule := b.cache.globalRules[ruleID]; rule != nil && rule.PeerIP == peerIP {
+ affected[peerID].RemoveFirewallRuleIDs = append(
+ affected[peerID].RemoveFirewallRuleIDs,
+ ruleID,
+ )
+ }
+ }
+ }
+
+ return affected
+}
+
+type PeerDeletionUpdate struct {
+ RemovePeerID string
+ RemoveFirewallRuleIDs []string
+ RemoveRouteIDs []route.ID
+ RemoveFromSourceRanges bool
+ PeerIP string
+}
+
+func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) {
+ if aclView := b.cache.peerACLs[peerID]; aclView != nil {
+ aclView.ConnectedPeerIDs = slices.DeleteFunc(aclView.ConnectedPeerIDs, func(id string) bool {
+ return id == updates.RemovePeerID
+ })
+
+ if len(updates.RemoveFirewallRuleIDs) > 0 {
+ aclView.FirewallRuleIDs = slices.DeleteFunc(aclView.FirewallRuleIDs, func(ruleID string) bool {
+ return slices.Contains(updates.RemoveFirewallRuleIDs, ruleID)
+ })
+ }
+ }
+
+ if routesView := b.cache.peerRoutes[peerID]; routesView != nil {
+ if len(updates.RemoveRouteIDs) > 0 {
+ routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool {
+ return slices.Contains(updates.RemoveRouteIDs, routeID)
+ })
+ }
+
+ if updates.RemoveFromSourceRanges {
+ b.removeIPFromRouteFirewallRules(routesView, updates.PeerIP)
+ }
+ }
+}
+
+func (b *NetworkMapBuilder) removeIPFromRouteFirewallRules(routesView *PeerRoutesView, peerIP string) {
+ sourceIPv4 := peerIP + "/32"
+ sourceIPv6 := peerIP + "/128"
+
+ rulesToRemove := []string{}
+
+ for _, ruleID := range routesView.RouteFirewallRuleIDs {
+ if rule := b.cache.globalRouteRules[ruleID]; rule != nil {
+ rule.SourceRanges = slices.DeleteFunc(rule.SourceRanges, func(source string) bool {
+ return source == sourceIPv4 || source == sourceIPv6 || source == peerIP
+ })
+
+ if len(rule.SourceRanges) == 0 {
+ rulesToRemove = append(rulesToRemove, ruleID)
+ }
+ }
+ }
+
+ if len(rulesToRemove) > 0 {
+ routesView.RouteFirewallRuleIDs = slices.DeleteFunc(routesView.RouteFirewallRuleIDs, func(ruleID string) bool {
+ return slices.Contains(rulesToRemove, ruleID)
+ })
+ }
+}
+
+func (b *NetworkMapBuilder) cleanupUnusedRules() {
+ usedFirewallRules := make(map[string]struct{})
+ usedRouteRules := make(map[string]struct{})
+ usedRoutes := make(map[route.ID]struct{})
+
+ for _, aclView := range b.cache.peerACLs {
+ for _, ruleID := range aclView.FirewallRuleIDs {
+ usedFirewallRules[ruleID] = struct{}{}
+ }
+ }
+
+ for _, routesView := range b.cache.peerRoutes {
+ for _, ruleID := range routesView.RouteFirewallRuleIDs {
+ usedRouteRules[ruleID] = struct{}{}
+ }
+
+ for _, routeID := range routesView.OwnRouteIDs {
+ usedRoutes[routeID] = struct{}{}
+ }
+ for _, routeID := range routesView.NetworkResourceIDs {
+ usedRoutes[routeID] = struct{}{}
+ }
+ }
+
+ for ruleID := range b.cache.globalRules {
+ if _, used := usedFirewallRules[ruleID]; !used {
+ delete(b.cache.globalRules, ruleID)
+ }
+ }
+
+ for ruleID := range b.cache.globalRouteRules {
+ if _, used := usedRouteRules[ruleID]; !used {
+ delete(b.cache.globalRouteRules, ruleID)
+ }
+ }
+
+ for routeID := range b.cache.globalRoutes {
+ if _, used := usedRoutes[routeID]; !used {
+ delete(b.cache.globalRoutes, routeID)
+ }
+ }
+}
+
+func (b *NetworkMapBuilder) UpdatePeer(peer *nbpeer.Peer) {
+ b.cache.mu.Lock()
+ defer b.cache.mu.Unlock()
+ peerStored, ok := b.cache.globalPeers[peer.ID]
+ if !ok {
+ return
+ }
+ *peerStored = *peer
+}
diff --git a/management/server/types/policy.go b/management/server/types/policy.go
index 5e86a87c6..d4e1a8816 100644
--- a/management/server/types/policy.go
+++ b/management/server/types/policy.go
@@ -23,6 +23,8 @@ const (
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
// PolicyRuleProtocolICMP type of traffic
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
+ // PolicyRuleProtocolNetbirdSSH type of traffic
+ PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh")
)
const (
@@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error)
protocol = PolicyRuleProtocolUDP
case "icmp":
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
+ case "netbird-ssh":
+ return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil
default:
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
}
diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go
index 2643ae45c..bb75dd555 100644
--- a/management/server/types/policyrule.go
+++ b/management/server/types/policyrule.go
@@ -80,6 +80,12 @@ type PolicyRule struct {
// PortRanges a list of port ranges.
PortRanges []RulePortRange `gorm:"serializer:json"`
+
+ // AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh
+ AuthorizedGroups map[string][]string `gorm:"serializer:json"`
+
+ // AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
+ AuthorizedUser string
}
// Copy returns a copy of a policy rule
@@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule {
Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)),
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
+ AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
+ AuthorizedUser: pm.AuthorizedUser,
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
copy(rule.PortRanges, pm.PortRanges)
+ for k, v := range pm.AuthorizedGroups {
+ rule.AuthorizedGroups[k] = make([]string, len(v))
+ copy(rule.AuthorizedGroups[k], v)
+ }
return rule
}
diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go
index 6eb391cb5..da29e1d87 100644
--- a/management/server/types/route_firewall_rule.go
+++ b/management/server/types/route_firewall_rule.go
@@ -1,8 +1,8 @@
package types
import (
- "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
+ "github.com/netbirdio/netbird/shared/management/domain"
)
// RouteFirewallRule a firewall rule applicable for a routed network.
diff --git a/management/server/types/settings.go b/management/server/types/settings.go
index b4afb2f5e..867e12bef 100644
--- a/management/server/types/settings.go
+++ b/management/server/types/settings.go
@@ -52,6 +52,9 @@ type Settings struct {
// LazyConnectionEnabled indicates if the experimental feature is enabled or disabled
LazyConnectionEnabled bool `gorm:"default:false"`
+
+ // AutoUpdateVersion client auto-update version
+ AutoUpdateVersion string `gorm:"default:'disabled'"`
}
// Copy copies the Settings struct
@@ -72,6 +75,7 @@ func (s *Settings) Copy() *Settings {
LazyConnectionEnabled: s.LazyConnectionEnabled,
DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,
+ AutoUpdateVersion: s.AutoUpdateVersion,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()
diff --git a/management/server/user.go b/management/server/user.go
index d40d33c6a..9d4620462 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -7,12 +7,13 @@ import (
"strings"
"time"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/shared/auth"
+
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
- nbContext "github.com/netbirdio/netbird/management/server/context"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -175,9 +176,9 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
}
-// GetUser looks up a user by provided nbContext.UserAuths.
+// GetUser looks up a user by provided auth.UserAuths.
// Expects account to have been created already.
-func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
+func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
return nil, err
@@ -262,15 +263,11 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return err
}
- updateAccountPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
+ _, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
if err != nil {
return err
}
- if updateAccountPeers {
- am.UpdateAccountPeers(ctx, accountID)
- }
-
return nil
}
@@ -526,16 +523,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
- userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
+ _, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
)
if err != nil {
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
}
- if userHadPeers {
- updateAccountPeers = true
- }
+ updateAccountPeers = true
err = transaction.SaveUser(ctx, updatedUser)
if err != nil {
@@ -584,7 +579,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
}
- if settings.GroupsPropagationEnabled && updateAccountPeers {
+ if updateAccountPeers {
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
return nil, fmt.Errorf("failed to increment network serial: %w", err)
}
@@ -595,9 +590,15 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
-func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() {
+func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, isNewUser bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() {
var eventsToStore []func()
+ if isNewUser {
+ eventsToStore = append(eventsToStore, func() {
+ am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.UserCreated, nil)
+ })
+ }
+
if oldUser.IsBlocked() != newUser.IsBlocked() {
if newUser.IsBlocked() {
eventsToStore = append(eventsToStore, func() {
@@ -621,6 +622,35 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac
})
}
+ addedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, addedGroupIDs)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to get added groups for user %s update event: %v", oldUser.Id, err)
+ }
+
+ for _, group := range addedGroups {
+ meta := map[string]any{
+ "group": group.Name, "group_id": group.ID,
+ "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName,
+ }
+ eventsToStore = append(eventsToStore, func() {
+ am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupAddedToUser, meta)
+ })
+ }
+
+ removedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, removedGroupIDs)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to get removed groups for user %s update event: %v", oldUser.Id, err)
+ }
+ for _, group := range removedGroups {
+ meta := map[string]any{
+ "group": group.Name, "group_id": group.ID,
+ "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName,
+ }
+ eventsToStore = append(eventsToStore, func() {
+ am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta)
+ })
+ }
+
return eventsToStore
}
@@ -631,7 +661,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
- oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists)
+ oldUser, isNewUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists)
if err != nil {
return false, nil, nil, nil, err
}
@@ -667,9 +697,10 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
peersToExpire = userPeers
}
+ var removedGroups, addedGroups []string
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
- removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
- addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups)
+ removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups)
+ addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups)
for _, peer := range userPeers {
for _, groupID := range removedGroups {
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil {
@@ -685,30 +716,30 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
}
updateAccountPeers := len(userPeers) > 0
- userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole)
+ userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroups, addedGroups, transaction)
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
}
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
-func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
+func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, bool, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
if !addIfNotExists {
- return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
+ return nil, false, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
}
update.AccountID = accountID
- return update, nil // use all fields from update if addIfNotExists is true
+ return update, true, nil // use all fields from update if addIfNotExists is true
}
- return nil, err
+ return nil, false, err
}
if existingUser.AccountID != accountID {
- return nil, status.Errorf(status.InvalidArgument, "user account ID mismatch")
+ return nil, false, status.Errorf(status.InvalidArgument, "user account ID mismatch")
}
- return existingUser, nil
+ return existingUser, false, nil
}
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) {
@@ -935,12 +966,12 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if err != nil {
return err
}
- dnsDomain := am.GetDNSDomain(settings)
+ dnsDomain := am.networkMapController.GetDNSDomain(settings)
var peerIDs []string
for _, peer := range peers {
// nolint:staticcheck
- ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
+ ctx = context.WithValue(ctx, nbcontext.PeerIDKey, peer.Key)
if peer.UserID == "" {
// we do not want to expire peers that are added via setup key
@@ -963,11 +994,15 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
)
}
+ err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs)
+ if err != nil {
+ return fmt.Errorf("notify network map controller of peer update: %w", err)
+ }
+
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID)
- am.peersUpdateManager.CloseChannels(ctx, peerIDs)
- am.BufferUpdateAccountPeers(ctx, accountID)
+ am.networkMapController.DisconnectPeers(ctx, accountID, peerIDs)
}
return nil
}
@@ -1013,7 +1048,6 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
}
var allErrors error
- var updateAccountPeers bool
for _, targetUserID := range targetUserIDs {
if initiatorUserID == targetUserID {
@@ -1044,19 +1078,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue
}
- userHadPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
+ _, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
-
- if userHadPeers {
- updateAccountPeers = true
- }
- }
-
- if updateAccountPeers {
- am.UpdateAccountPeers(ctx, accountID)
}
return allErrors
@@ -1081,6 +1107,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var addPeerRemovedEvents []func()
var updateAccountPeers bool
+ var userPeers []*nbpeer.Peer
var targetUser *types.User
var err error
@@ -1090,7 +1117,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
return fmt.Errorf("failed to get user to delete: %w", err)
}
- userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
+ userPeers, err = transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
if err != nil {
return fmt.Errorf("failed to get user peers: %w", err)
}
@@ -1113,6 +1140,14 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
return false, err
}
+ var peerIDs []string
+ for _, peer := range userPeers {
+ peerIDs = append(peerIDs, peer.ID)
+ }
+ if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil {
+ log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err)
+ }
+
for _, addPeerRemovedEvent := range addPeerRemovedEvents {
addPeerRemovedEvent()
}
@@ -1171,7 +1206,7 @@ func validateUserInvite(invite *types.UserInfo) error {
}
// GetCurrentUserInfo retrieves the account's current user info and permissions
-func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
+func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
accountID, userID := userAuth.AccountId, userAuth.UserId
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
diff --git a/management/server/user_test.go b/management/server/user_test.go
index 5920a2a33..3032ee3e8 100644
--- a/management/server/user_test.go
+++ b/management/server/user_test.go
@@ -8,15 +8,17 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "go.uber.org/mock/gomock"
"golang.org/x/exp/maps"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbcache "github.com/netbirdio/netbird/management/server/cache"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server/util"
+ "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -547,7 +549,7 @@ func TestUser_InviteNewUser(t *testing.T) {
permissionsManager: permissionsManager,
}
- cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
+ cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
require.NoError(t, err)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cs)
@@ -739,11 +741,18 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
+ ctrl := gomock.NewController(t)
+ networkMapControllerMock := network_map.NewMockController(ctrl)
+ networkMapControllerMock.EXPECT().
+ OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(nil)
+
permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
- Store: store,
- eventStore: &activity.InMemoryEventStore{},
- permissionsManager: permissionsManager,
+ Store: store,
+ eventStore: &activity.InMemoryEventStore{},
+ permissionsManager: permissionsManager,
+ networkMapController: networkMapControllerMock,
}
testCases := []struct {
@@ -848,12 +857,20 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
+ ctrl := gomock.NewController(t)
+ networkMapControllerMock := network_map.NewMockController(ctrl)
+ networkMapControllerMock.EXPECT().
+ OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(nil).
+ AnyTimes()
+
permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MockIntegratedValidator{},
permissionsManager: permissionsManager,
+ networkMapController: networkMapControllerMock,
}
testCases := []struct {
@@ -966,7 +983,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
permissionsManager: permissionsManager,
}
- claims := nbcontext.UserAuth{
+ claims := auth.UserAuth{
UserId: mockUserID,
AccountId: mockAccountID,
}
@@ -1056,7 +1073,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
permissionsManager: permissionsManager,
}
- cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
+ cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
assert.NoError(t, err)
am.externalCacheManager = nbcache.NewUserDataCache(cacheStore)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore)
@@ -1161,7 +1178,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
}
func TestDefaultAccountManager_SaveUser(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1333,7 +1350,7 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
func TestUserAccountPeersUpdate(t *testing.T) {
// account groups propagation is enabled
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+ manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
@@ -1357,16 +1374,16 @@ func TestUserAccountPeersUpdate(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err)
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
+ updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
+ updateManager.CloseChannel(context.Background(), peer1.ID)
})
- // Creating a new regular user should not update account peers and not send peer update
+ // Creating a new regular user should send peer update (as users are not filtered yet)
t.Run("creating new regular user with no groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
- peerShouldNotReceiveUpdate(t, updMsg)
+ peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1385,11 +1402,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}
})
- // updating user with no linked peers should not update account peers and not send peer update
+ // updating user with no linked peers should update account peers and send peer update (as users are not filtered yet)
t.Run("updating user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
- peerShouldNotReceiveUpdate(t, updMsg)
+ peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1412,7 +1429,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
t.Run("deleting user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
- peerShouldNotReceiveUpdate(t, updMsg)
+ peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1468,9 +1485,9 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}
})
- peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID)
+ peer4UpdMsg := updateManager.CreateChannel(context.Background(), peer4.ID)
t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID)
+ updateManager.CloseChannel(context.Background(), peer4.ID)
})
// deleting user with linked peers should update account peers and send peer update
@@ -1573,33 +1590,33 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
tt := []struct {
name string
- userAuth nbcontext.UserAuth
+ userAuth auth.UserAuth
expectedErr error
expectedResult *users.UserInfoWithPermissions
}{
{
name: "not found",
- userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"},
+ userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "not-found"},
expectedErr: status.NewUserNotFoundError("not-found"),
},
{
name: "not part of account",
- userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"},
+ userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account2Owner"},
expectedErr: status.NewUserNotPartOfAccountError(),
},
{
name: "blocked",
- userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"},
+ userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "blocked-user"},
expectedErr: status.NewUserBlockedError(),
},
{
name: "service user",
- userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"},
+ userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "service-user"},
expectedErr: status.NewPermissionDeniedError(),
},
{
name: "owner user",
- userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"},
+ userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account1Owner"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "account1Owner",
@@ -1619,7 +1636,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
},
{
name: "regular user",
- userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"},
+ userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "regular-user"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "regular-user",
@@ -1638,7 +1655,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
},
{
name: "admin user",
- userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"},
+ userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "admin-user"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "admin-user",
@@ -1657,7 +1674,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
},
{
name: "settings blocked regular user",
- userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"},
+ userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "settings-blocked-user",
@@ -1678,7 +1695,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
{
name: "settings blocked regular user child account",
- userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true},
+ userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "settings-blocked-user",
@@ -1698,7 +1715,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
},
{
name: "settings blocked owner user",
- userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"},
+ userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "account2Owner"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "account2Owner",
@@ -1748,7 +1765,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
}
func TestApproveUser(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -1807,7 +1824,7 @@ func TestApproveUser(t *testing.T) {
}
func TestRejectUser(t *testing.T) {
- manager, err := createManager(t)
+ manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
diff --git a/relay/cmd/root.go b/relay/cmd/root.go
index eb2cdebf8..e7dadcfdf 100644
--- a/relay/cmd/root.go
+++ b/relay/cmd/root.go
@@ -160,7 +160,8 @@ func execute(cmd *cobra.Command, args []string) error {
log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err)
}
- log.Infof("server will be available on: %s", srv.InstanceURL())
+ instanceURL := srv.InstanceURL()
+ log.Infof("server will be available on: %s", instanceURL.String())
wg.Add(1)
go func() {
defer wg.Done()
diff --git a/relay/healthcheck/healthcheck.go b/relay/healthcheck/healthcheck.go
index eedd62394..b54d4b33b 100644
--- a/relay/healthcheck/healthcheck.go
+++ b/relay/healthcheck/healthcheck.go
@@ -6,14 +6,14 @@ import (
"errors"
"net"
"net/http"
+ "net/url"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
- "github.com/netbirdio/netbird/relay/server/listener/quic"
- "github.com/netbirdio/netbird/relay/server/listener/ws"
+ "github.com/netbirdio/netbird/relay/server"
)
const (
@@ -27,7 +27,7 @@ const (
type ServiceChecker interface {
ListenerProtocols() []protocol.Protocol
- ListenAddress() string
+ InstanceURL() url.URL
}
type HealthStatus struct {
@@ -135,7 +135,11 @@ func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) {
}
status.Listeners = listeners
- if ok := s.validateCertificate(ctx); !ok {
+ if s.config.ServiceChecker.InstanceURL().Scheme != server.SchemeRELS {
+ status.CertificateValid = false
+ }
+
+ if ok := s.validateConnection(ctx); !ok {
status.Status = statusUnhealthy
status.CertificateValid = false
healthy = false
@@ -152,32 +156,13 @@ func (s *Server) validateListeners() ([]protocol.Protocol, bool) {
return listeners, true
}
-func (s *Server) validateCertificate(ctx context.Context) bool {
- listenAddress := s.config.ServiceChecker.ListenAddress()
- if listenAddress == "" {
- log.Warn("listen address is empty")
+func (s *Server) validateConnection(ctx context.Context) bool {
+ addr := s.config.ServiceChecker.InstanceURL()
+ if err := dialWS(ctx, addr); err != nil {
+ log.Errorf("failed to dial WebSocket listener at %s: %v", addr.String(), err)
return false
}
- dAddr := dialAddress(listenAddress)
-
- for _, proto := range s.config.ServiceChecker.ListenerProtocols() {
- switch proto {
- case ws.Proto:
- if err := dialWS(ctx, dAddr); err != nil {
- log.Errorf("failed to dial WebSocket listener: %v", err)
- return false
- }
- case quic.Proto:
- if err := dialQUIC(ctx, dAddr); err != nil {
- log.Errorf("failed to dial QUIC listener: %v", err)
- return false
- }
- default:
- log.Warnf("unknown protocol for healthcheck: %s", proto)
- return false
- }
- }
return true
}
@@ -187,8 +172,9 @@ func dialAddress(listenAddress string) string {
return listenAddress // fallback, might be invalid for dialing
}
+ // When listening on all interfaces, show localhost for better readability
if host == "" || host == "::" || host == "0.0.0.0" {
- host = "0.0.0.0"
+ host = "localhost"
}
return net.JoinHostPort(host, port)
diff --git a/relay/healthcheck/peerid/peerid.go b/relay/healthcheck/peerid/peerid.go
new file mode 100644
index 000000000..cd8696817
--- /dev/null
+++ b/relay/healthcheck/peerid/peerid.go
@@ -0,0 +1,31 @@
+package peerid
+
+import (
+ "crypto/sha256"
+
+ v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
+ "github.com/netbirdio/netbird/shared/relay/messages"
+)
+
+var (
+ // HealthCheckPeerID is the hashed peer ID for health check connections
+ HealthCheckPeerID = messages.HashID("healthcheck-agent")
+
+ // DummyAuthToken is a structurally valid auth token for health check.
+ // The signature is not valid but the format is correct (1 byte algo + 32 bytes signature + payload).
+ DummyAuthToken = createDummyToken()
+)
+
+func createDummyToken() []byte {
+ token := v2.Token{
+ AuthAlgo: v2.AuthAlgoHMACSHA256,
+ Signature: make([]byte, sha256.Size),
+ Payload: []byte("healthcheck"),
+ }
+ return token.Marshal()
+}
+
+// IsHealthCheck checks if the given peer ID is the health check agent
+func IsHealthCheck(peerID *messages.PeerID) bool {
+ return peerID != nil && *peerID == HealthCheckPeerID
+}
diff --git a/relay/healthcheck/quic.go b/relay/healthcheck/quic.go
deleted file mode 100644
index 1582edf7b..000000000
--- a/relay/healthcheck/quic.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package healthcheck
-
-import (
- "context"
- "crypto/tls"
- "fmt"
- "time"
-
- "github.com/quic-go/quic-go"
-
- tlsnb "github.com/netbirdio/netbird/shared/relay/tls"
-)
-
-func dialQUIC(ctx context.Context, address string) error {
- tlsConfig := &tls.Config{
- InsecureSkipVerify: false, // Keep certificate validation enabled
- NextProtos: []string{tlsnb.NBalpn},
- }
-
- conn, err := quic.DialAddr(ctx, address, tlsConfig, &quic.Config{
- MaxIdleTimeout: 30 * time.Second,
- KeepAlivePeriod: 10 * time.Second,
- EnableDatagrams: true,
- })
- if err != nil {
- return fmt.Errorf("failed to connect to QUIC server: %w", err)
- }
-
- _ = conn.CloseWithError(0, "availability check complete")
- return nil
-}
diff --git a/relay/healthcheck/ws.go b/relay/healthcheck/ws.go
index 49694356c..9267096f5 100644
--- a/relay/healthcheck/ws.go
+++ b/relay/healthcheck/ws.go
@@ -3,26 +3,47 @@ package healthcheck
import (
"context"
"fmt"
+ "net/url"
"github.com/coder/websocket"
+ "github.com/netbirdio/netbird/relay/healthcheck/peerid"
+ "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay"
+ "github.com/netbirdio/netbird/shared/relay/messages"
)
-func dialWS(ctx context.Context, address string) error {
- url := fmt.Sprintf("wss://%s%s", address, relay.WebSocketURLPath)
+func dialWS(ctx context.Context, address url.URL) error {
+ scheme := "ws"
+ if address.Scheme == server.SchemeRELS {
+ scheme = "wss"
+ }
+ wsURL := fmt.Sprintf("%s://%s%s", scheme, address.Host, relay.WebSocketURLPath)
- conn, resp, err := websocket.Dial(ctx, url, nil)
+ conn, resp, err := websocket.Dial(ctx, wsURL, nil)
if resp != nil {
defer func() {
- _ = resp.Body.Close()
+ if resp.Body != nil {
+ _ = resp.Body.Close()
+ }
}()
}
if err != nil {
return fmt.Errorf("failed to connect to websocket: %w", err)
}
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ authMsg, err := messages.MarshalAuthMsg(peerid.HealthCheckPeerID, peerid.DummyAuthToken)
+ if err != nil {
+ return fmt.Errorf("failed to marshal auth message: %w", err)
+ }
+
+ if err := conn.Write(ctx, websocket.MessageBinary, authMsg); err != nil {
+ return fmt.Errorf("failed to write auth message: %w", err)
+ }
- _ = conn.Close(websocket.StatusNormalClosure, "availability check complete")
return nil
}
diff --git a/relay/server/handshake.go b/relay/server/handshake.go
index 922369798..8c3ee1899 100644
--- a/relay/server/handshake.go
+++ b/relay/server/handshake.go
@@ -97,7 +97,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
}
if err != nil {
- return nil, err
+ return peerID, err
}
h.peerID = peerID
return peerID, nil
@@ -147,7 +147,7 @@ func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
}
if err := h.validator.Validate(authPayload); err != nil {
- return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
+ return rawPeerID, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
}
return rawPeerID, nil
diff --git a/relay/server/peer.go b/relay/server/peer.go
index c47f2e960..c5ff41857 100644
--- a/relay/server/peer.go
+++ b/relay/server/peer.go
@@ -9,10 +9,10 @@ import (
log "github.com/sirupsen/logrus"
- "github.com/netbirdio/netbird/shared/relay/healthcheck"
- "github.com/netbirdio/netbird/shared/relay/messages"
"github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
+ "github.com/netbirdio/netbird/shared/relay/healthcheck"
+ "github.com/netbirdio/netbird/shared/relay/messages"
)
const (
diff --git a/relay/server/relay.go b/relay/server/relay.go
index d86684937..bb355f58f 100644
--- a/relay/server/relay.go
+++ b/relay/server/relay.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
+ "net/url"
"sync"
"time"
@@ -11,6 +12,7 @@ import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric"
+ "github.com/netbirdio/netbird/relay/healthcheck/peerid"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
@@ -22,7 +24,7 @@ type Config struct {
TLSSupport bool
AuthValidator Validator
- instanceURL string
+ instanceURL url.URL
}
func (c *Config) validate() error {
@@ -37,7 +39,7 @@ func (c *Config) validate() error {
if err != nil {
return fmt.Errorf("invalid url: %v", err)
}
- c.instanceURL = instanceURL
+ c.instanceURL = *instanceURL
if c.AuthValidator == nil {
return fmt.Errorf("auth validator is required")
@@ -51,10 +53,11 @@ type Relay struct {
metricsCancel context.CancelFunc
validator Validator
- store *store.Store
- notifier *store.PeerNotifier
- instanceURL string
- preparedMsg *preparedMsg
+ store *store.Store
+ notifier *store.PeerNotifier
+ instanceURL url.URL
+ exposedAddress string
+ preparedMsg *preparedMsg
closed bool
closeMu sync.RWMutex
@@ -87,15 +90,16 @@ func NewRelay(config Config) (*Relay, error) {
}
r := &Relay{
- metrics: m,
- metricsCancel: metricsCancel,
- validator: config.AuthValidator,
- instanceURL: config.instanceURL,
- store: store.NewStore(),
- notifier: store.NewPeerNotifier(),
+ metrics: m,
+ metricsCancel: metricsCancel,
+ validator: config.AuthValidator,
+ instanceURL: config.instanceURL,
+ exposedAddress: config.ExposedAddress,
+ store: store.NewStore(),
+ notifier: store.NewPeerNotifier(),
}
- r.preparedMsg, err = newPreparedMsg(r.instanceURL)
+ r.preparedMsg, err = newPreparedMsg(r.instanceURL.String())
if err != nil {
metricsCancel()
return nil, fmt.Errorf("prepare message: %v", err)
@@ -120,7 +124,11 @@ func (r *Relay) Accept(conn net.Conn) {
}
peerID, err := h.handshakeReceive()
if err != nil {
- log.Errorf("failed to handshake: %s", err)
+ if peerid.IsHealthCheck(peerID) {
+ log.Debugf("health check connection from %s", conn.RemoteAddr())
+ } else {
+ log.Errorf("failed to handshake: %s", err)
+ }
if cErr := conn.Close(); cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}
@@ -175,6 +183,6 @@ func (r *Relay) Shutdown(ctx context.Context) {
}
// InstanceURL returns the instance URL of the relay server
-func (r *Relay) InstanceURL() string {
+func (r *Relay) InstanceURL() url.URL {
return r.instanceURL
}
diff --git a/relay/server/server.go b/relay/server/server.go
index 4c30e7fdc..8e4333064 100644
--- a/relay/server/server.go
+++ b/relay/server/server.go
@@ -3,6 +3,7 @@ package server
import (
"context"
"crypto/tls"
+ "net/url"
"sync"
"github.com/hashicorp/go-multierror"
@@ -28,8 +29,6 @@ type ListenerConfig struct {
// It is the gate between the WebSocket listener and the Relay server logic.
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct {
- listenAddr string
-
relay *Relay
listeners []listener.Listener
listenerMux sync.Mutex
@@ -41,7 +40,7 @@ type Server struct {
//
// config: A Config struct containing the necessary configuration:
// - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used.
-// - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required.
+// - InstanceURL: The public address (in domain:port format) used as the server's instance URL. Required.
// - TLSSupport: A boolean indicating whether TLS is enabled for the server.
// - AuthValidator: A Validator used to authenticate peers. Required.
//
@@ -62,8 +61,6 @@ func NewServer(config Config) (*Server, error) {
// Listen starts the relay server.
func (r *Server) Listen(cfg ListenerConfig) error {
- r.listenAddr = cfg.Address
-
wSListener := &ws.Listener{
Address: cfg.Address,
TLSConfig: cfg.TLSConfig,
@@ -123,11 +120,6 @@ func (r *Server) Shutdown(ctx context.Context) error {
return nberrors.FormatErrorOrNil(multiErr)
}
-// InstanceURL returns the instance URL of the relay server.
-func (r *Server) InstanceURL() string {
- return r.relay.instanceURL
-}
-
func (r *Server) ListenerProtocols() []protocol.Protocol {
result := make([]protocol.Protocol, 0)
@@ -139,6 +131,6 @@ func (r *Server) ListenerProtocols() []protocol.Protocol {
return result
}
-func (r *Server) ListenAddress() string {
- return r.listenAddr
+func (r *Server) InstanceURL() url.URL {
+ return r.relay.InstanceURL()
}
diff --git a/relay/server/url.go b/relay/server/url.go
index 9cbf44642..aeae1c068 100644
--- a/relay/server/url.go
+++ b/relay/server/url.go
@@ -6,9 +6,14 @@ import (
"strings"
)
+const (
+ SchemeREL = "rel"
+ SchemeRELS = "rels"
+)
+
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
// provided address according to TLS definition and parses the address before returning it
-func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
+func getInstanceURL(exposedAddress string, tlsSupported bool) (*url.URL, error) {
addr := exposedAddress
split := strings.Split(exposedAddress, "://")
switch {
@@ -17,17 +22,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
case len(split) == 1 && !tlsSupported:
addr = "rel://" + exposedAddress
case len(split) > 2:
- return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
+ return nil, fmt.Errorf("invalid exposed address: %s", exposedAddress)
}
parsedURL, err := url.ParseRequestURI(addr)
if err != nil {
- return "", fmt.Errorf("invalid exposed address: %v", err)
+ return nil, fmt.Errorf("invalid exposed address: %v", err)
}
- if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
- return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
+ if parsedURL.Scheme != SchemeREL && parsedURL.Scheme != SchemeRELS {
+ return nil, fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
}
- return parsedURL.String(), nil
+ // Validate scheme matches TLS configuration
+ if tlsSupported && parsedURL.Scheme == SchemeREL {
+ return nil, fmt.Errorf("non-TLS scheme '%s' provided but TLS is supported", SchemeREL)
+ }
+
+ return parsedURL, nil
}
diff --git a/relay/server/relay_test.go b/relay/server/url_test.go
similarity index 78%
rename from relay/server/relay_test.go
rename to relay/server/url_test.go
index 062039ab9..ca455f45a 100644
--- a/relay/server/relay_test.go
+++ b/relay/server/url_test.go
@@ -13,7 +13,7 @@ func TestGetInstanceURL(t *testing.T) {
{"Valid address with TLS", "example.com", true, "rels://example.com", false},
{"Valid address without TLS", "example.com", false, "rel://example.com", false},
{"Valid address with scheme", "rel://example.com", false, "rel://example.com", false},
- {"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false},
+ {"Invalid address with non TLS scheme and TLS true", "rel://example.com", true, "", true},
{"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false},
{"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false},
{"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false},
@@ -28,8 +28,11 @@ func TestGetInstanceURL(t *testing.T) {
if (err != nil) != tt.expectError {
t.Errorf("expected error: %v, got: %v", tt.expectError, err)
}
- if url != tt.expectedURL {
- t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url)
+ if !tt.expectError && url != nil && url.String() != tt.expectedURL {
+ t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url.String())
+ }
+ if tt.expectError && url != nil {
+ t.Errorf("expected nil URL on error, got: %s", url.String())
}
})
}
diff --git a/release_files/freebsd-port-diff.sh b/release_files/freebsd-port-diff.sh
new file mode 100755
index 000000000..b030b9164
--- /dev/null
+++ b/release_files/freebsd-port-diff.sh
@@ -0,0 +1,216 @@
+#!/bin/bash
+#
+# FreeBSD Port Diff Generator for NetBird
+#
+# This script generates the diff file required for submitting a FreeBSD port update.
+# It works on macOS, Linux, and FreeBSD by fetching files from FreeBSD cgit and
+# computing checksums from the Go module proxy.
+#
+# Usage: ./freebsd-port-diff.sh [new_version]
+# Example: ./freebsd-port-diff.sh 0.60.7
+#
+# If no version is provided, it fetches the latest from GitHub.
+
+set -e
+
+GITHUB_REPO="netbirdio/netbird"
+PORTS_CGIT_BASE="https://cgit.freebsd.org/ports/plain/security/netbird"
+GO_PROXY="https://proxy.golang.org/github.com/netbirdio/netbird/@v"
+OUTPUT_DIR="${OUTPUT_DIR:-.}"
+AWK_FIRST_FIELD='{print $1}'
+
+fetch_all_tags() {
+ curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
+ grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
+ sed 's/.*\/v//' | \
+ sort -u -V
+ return 0
+}
+
+fetch_current_ports_version() {
+ echo "Fetching current version from FreeBSD ports..." >&2
+ curl -sL "${PORTS_CGIT_BASE}/Makefile" 2>/dev/null | \
+ grep -E "^DISTVERSION=" | \
+ sed 's/DISTVERSION=[[:space:]]*//' | \
+ tr -d '\t '
+ return 0
+}
+
+fetch_latest_github_release() {
+ echo "Fetching latest release from GitHub..." >&2
+ fetch_all_tags | tail -1
+ return 0
+}
+
+fetch_ports_file() {
+ local filename="$1"
+ curl -sL "${PORTS_CGIT_BASE}/${filename}" 2>/dev/null
+ return 0
+}
+
+compute_checksums() {
+ local version="$1"
+ local tmpdir
+ tmpdir=$(mktemp -d)
+ # shellcheck disable=SC2064
+ trap "rm -rf '$tmpdir'" EXIT
+
+ echo "Downloading files from Go module proxy for v${version}..." >&2
+
+ local mod_file="${tmpdir}/v${version}.mod"
+ local zip_file="${tmpdir}/v${version}.zip"
+
+ curl -sL "${GO_PROXY}/v${version}.mod" -o "$mod_file" 2>/dev/null
+ curl -sL "${GO_PROXY}/v${version}.zip" -o "$zip_file" 2>/dev/null
+
+ if [[ ! -s "$mod_file" ]] || [[ ! -s "$zip_file" ]]; then
+ echo "Error: Could not download files from Go module proxy" >&2
+ return 1
+ fi
+
+ local mod_sha256 mod_size zip_sha256 zip_size
+
+ if command -v sha256sum &>/dev/null; then
+ mod_sha256=$(sha256sum "$mod_file" | awk "$AWK_FIRST_FIELD")
+ zip_sha256=$(sha256sum "$zip_file" | awk "$AWK_FIRST_FIELD")
+ elif command -v shasum &>/dev/null; then
+ mod_sha256=$(shasum -a 256 "$mod_file" | awk "$AWK_FIRST_FIELD")
+ zip_sha256=$(shasum -a 256 "$zip_file" | awk "$AWK_FIRST_FIELD")
+ else
+ echo "Error: No sha256 command found" >&2
+ return 1
+ fi
+
+ if [[ "$OSTYPE" == "darwin"* ]]; then
+ mod_size=$(stat -f%z "$mod_file")
+ zip_size=$(stat -f%z "$zip_file")
+ else
+ mod_size=$(stat -c%s "$mod_file")
+ zip_size=$(stat -c%s "$zip_file")
+ fi
+
+ echo "TIMESTAMP = $(date +%s)"
+ echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_sha256}"
+ echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_size}"
+ echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_sha256}"
+ echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_size}"
+ return 0
+}
+
+generate_new_makefile() {
+ local new_version="$1"
+ local old_makefile="$2"
+
+ # Check if old version had PORTREVISION
+ if echo "$old_makefile" | grep -q "^PORTREVISION="; then
+ # Remove PORTREVISION line and update DISTVERSION
+ echo "$old_makefile" | \
+ sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/" | \
+ grep -v "^PORTREVISION="
+ else
+ # Just update DISTVERSION
+ echo "$old_makefile" | \
+ sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/"
+ fi
+ return 0
+}
+
+# Parse arguments
+NEW_VERSION="${1:-}"
+
+# Auto-detect versions if not provided
+OLD_VERSION=$(fetch_current_ports_version)
+if [[ -z "$OLD_VERSION" ]]; then
+ echo "Error: Could not fetch current version from FreeBSD ports" >&2
+ exit 1
+fi
+echo "Current FreeBSD ports version: ${OLD_VERSION}" >&2
+
+if [[ -z "$NEW_VERSION" ]]; then
+ NEW_VERSION=$(fetch_latest_github_release)
+ if [[ -z "$NEW_VERSION" ]]; then
+ echo "Error: Could not fetch latest release from GitHub" >&2
+ exit 1
+ fi
+fi
+echo "Target version: ${NEW_VERSION}" >&2
+
+if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then
+ echo "Port is already at version ${NEW_VERSION}. Nothing to do." >&2
+ exit 0
+fi
+
+echo "" >&2
+
+# Fetch current files
+echo "Fetching current Makefile from FreeBSD ports..." >&2
+OLD_MAKEFILE=$(fetch_ports_file "Makefile")
+if [[ -z "$OLD_MAKEFILE" ]]; then
+ echo "Error: Could not fetch Makefile" >&2
+ exit 1
+fi
+
+echo "Fetching current distinfo from FreeBSD ports..." >&2
+OLD_DISTINFO=$(fetch_ports_file "distinfo")
+if [[ -z "$OLD_DISTINFO" ]]; then
+ echo "Error: Could not fetch distinfo" >&2
+ exit 1
+fi
+
+# Generate new files
+echo "Generating new Makefile..." >&2
+NEW_MAKEFILE=$(generate_new_makefile "$NEW_VERSION" "$OLD_MAKEFILE")
+
+echo "Computing checksums for new version..." >&2
+NEW_DISTINFO=$(compute_checksums "$NEW_VERSION")
+if [[ -z "$NEW_DISTINFO" ]]; then
+ echo "Error: Could not compute checksums" >&2
+ exit 1
+fi
+
+# Create temp files for diff
+TMPDIR=$(mktemp -d)
+# shellcheck disable=SC2064
+trap "rm -rf '$TMPDIR'" EXIT
+
+mkdir -p "${TMPDIR}/a/security/netbird" "${TMPDIR}/b/security/netbird"
+
+echo "$OLD_MAKEFILE" > "${TMPDIR}/a/security/netbird/Makefile"
+echo "$OLD_DISTINFO" > "${TMPDIR}/a/security/netbird/distinfo"
+echo "$NEW_MAKEFILE" > "${TMPDIR}/b/security/netbird/Makefile"
+echo "$NEW_DISTINFO" > "${TMPDIR}/b/security/netbird/distinfo"
+
+# Generate diff
+OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}.diff"
+
+echo "" >&2
+echo "Generating diff..." >&2
+
+# Generate diff and clean up temp paths to show standard a/b paths
+(cd "${TMPDIR}" && diff -ruN "a/security/netbird" "b/security/netbird") > "$OUTPUT_FILE" || true
+
+if [[ ! -s "$OUTPUT_FILE" ]]; then
+ echo "Error: Generated diff is empty" >&2
+ exit 1
+fi
+
+echo "" >&2
+echo "========================================="
+echo "Diff saved to: ${OUTPUT_FILE}"
+echo "========================================="
+echo ""
+cat "$OUTPUT_FILE"
+echo ""
+echo "========================================="
+echo ""
+echo "Next steps:"
+echo "1. Review the diff above"
+echo "2. Submit to https://bugs.freebsd.org/bugzilla/"
+echo "3. Use ./freebsd-port-issue-body.sh to generate the issue content"
+echo ""
+echo "For FreeBSD testing (optional but recommended):"
+echo " cd /usr/ports/security/netbird"
+echo " patch < ${OUTPUT_FILE}"
+echo " make stage && make stage-qa && make package && make install"
+echo " netbird status"
+echo " make deinstall"
diff --git a/release_files/freebsd-port-issue-body.sh b/release_files/freebsd-port-issue-body.sh
new file mode 100755
index 000000000..b7ad0f5b1
--- /dev/null
+++ b/release_files/freebsd-port-issue-body.sh
@@ -0,0 +1,159 @@
+#!/bin/bash
+#
+# FreeBSD Port Issue Body Generator for NetBird
+#
+# This script generates the issue body content for submitting a FreeBSD port update
+# to the FreeBSD Bugzilla at https://bugs.freebsd.org/bugzilla/
+#
+# Usage: ./freebsd-port-issue-body.sh [old_version] [new_version]
+# Example: ./freebsd-port-issue-body.sh 0.56.0 0.59.1
+#
+# If no versions are provided, the script will:
+# - Fetch OLD version from FreeBSD ports cgit (current version in ports tree)
+# - Fetch NEW version from latest NetBird GitHub release tag
+
+set -e
+
+GITHUB_REPO="netbirdio/netbird"
+PORTS_CGIT_URL="https://cgit.freebsd.org/ports/plain/security/netbird/Makefile"
+
+fetch_current_ports_version() {
+ echo "Fetching current version from FreeBSD ports..." >&2
+ local makefile_content
+ makefile_content=$(curl -sL "$PORTS_CGIT_URL" 2>/dev/null)
+ if [[ -z "$makefile_content" ]]; then
+ echo "Error: Could not fetch Makefile from FreeBSD ports" >&2
+ return 1
+ fi
+ echo "$makefile_content" | grep -E "^DISTVERSION=" | sed 's/DISTVERSION=[[:space:]]*//' | tr -d '\t '
+ return 0
+}
+
+fetch_all_tags() {
+ # Fetch tags from GitHub tags page (no rate limiting, no auth needed)
+ curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
+ grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
+ sed 's/.*\/v//' | \
+ sort -u -V
+ return 0
+}
+
+fetch_latest_github_release() {
+ echo "Fetching latest release from GitHub..." >&2
+ local latest
+
+ # Fetch from GitHub tags page
+ latest=$(fetch_all_tags | tail -1)
+
+ if [[ -z "$latest" ]]; then
+ # Fallback to GitHub API
+ latest=$(curl -sL "https://api.github.com/repos/${GITHUB_REPO}/releases/latest" 2>/dev/null | \
+ grep '"tag_name"' | sed 's/.*"tag_name": *"v\([^"]*\)".*/\1/')
+ fi
+
+ if [[ -z "$latest" ]]; then
+ echo "Error: Could not fetch latest release from GitHub" >&2
+ return 1
+ fi
+ echo "$latest"
+ return 0
+}
+
+OLD_VERSION="${1:-}"
+NEW_VERSION="${2:-}"
+
+if [[ -z "$OLD_VERSION" ]]; then
+ OLD_VERSION=$(fetch_current_ports_version)
+ if [[ -z "$OLD_VERSION" ]]; then
+ echo "Error: Could not determine old version. Please provide it manually." >&2
+ echo "Usage: $0 " >&2
+ exit 1
+ fi
+ echo "Detected OLD version from FreeBSD ports: $OLD_VERSION" >&2
+fi
+
+if [[ -z "$NEW_VERSION" ]]; then
+ NEW_VERSION=$(fetch_latest_github_release)
+ if [[ -z "$NEW_VERSION" ]]; then
+ echo "Error: Could not determine new version. Please provide it manually." >&2
+ echo "Usage: $0 " >&2
+ exit 1
+ fi
+ echo "Detected NEW version from GitHub: $NEW_VERSION" >&2
+fi
+
+if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then
+ echo "Warning: OLD and NEW versions are the same ($OLD_VERSION). Port may already be up to date." >&2
+fi
+
+echo "" >&2
+
+OUTPUT_DIR="${OUTPUT_DIR:-.}"
+
+fetch_releases_between_versions() {
+ echo "Fetching release history from GitHub..." >&2
+
+ # Fetch all tags and filter to those between OLD and NEW versions
+ fetch_all_tags | \
+ while read -r ver; do
+ if [[ "$(printf '%s\n' "$OLD_VERSION" "$ver" | sort -V | head -n1)" = "$OLD_VERSION" ]] && \
+ [[ "$(printf '%s\n' "$ver" "$NEW_VERSION" | sort -V | head -n1)" = "$ver" ]] && \
+ [[ "$ver" != "$OLD_VERSION" ]]; then
+ echo "$ver"
+ fi
+ done
+ return 0
+}
+
+generate_changelog_section() {
+ local releases
+ releases=$(fetch_releases_between_versions)
+
+ echo "Changelogs:"
+ if [[ -n "$releases" ]]; then
+ echo "$releases" | while read -r ver; do
+ echo "https://github.com/${GITHUB_REPO}/releases/tag/v${ver}"
+ done
+ else
+ echo "https://github.com/${GITHUB_REPO}/releases/tag/v${NEW_VERSION}"
+ fi
+ return 0
+}
+
+OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}-issue.txt"
+
+cat << EOF > "$OUTPUT_FILE"
+BUGZILLA ISSUE DETAILS
+======================
+
+Severity: Affects Some People
+
+Summary: security/netbird: Update to ${NEW_VERSION}
+
+Description:
+------------
+security/netbird: Update ${OLD_VERSION} => ${NEW_VERSION}
+
+$(generate_changelog_section)
+
+Commit log:
+https://github.com/${GITHUB_REPO}/compare/v${OLD_VERSION}...v${NEW_VERSION}
+EOF
+
+echo "========================================="
+echo "Issue body saved to: ${OUTPUT_FILE}"
+echo "========================================="
+echo ""
+cat "$OUTPUT_FILE"
+echo ""
+echo "========================================="
+echo ""
+echo "Next steps:"
+echo "1. Go to https://bugs.freebsd.org/bugzilla/ and login"
+echo "2. Click 'Report an update or defect to a port'"
+echo "3. Fill in:"
+echo " - Severity: Affects Some People"
+echo " - Summary: security/netbird: Update to ${NEW_VERSION}"
+echo " - Description: Copy content from ${OUTPUT_FILE}"
+echo "4. Attach diff file: netbird-${NEW_VERSION}.diff"
+echo "5. Submit the bug report"
diff --git a/release_files/install.sh b/release_files/install.sh
index 5d5349ec4..6a2c5f458 100755
--- a/release_files/install.sh
+++ b/release_files/install.sh
@@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then
NETBIRD_RELEASE=latest
fi
+TAG_NAME=""
+
get_release() {
local RELEASE=$1
if [ "$RELEASE" = "latest" ]; then
@@ -38,17 +40,19 @@ get_release() {
local TAG="tags/${RELEASE}"
local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}"
fi
+ OUTPUT=""
if [ -n "$GITHUB_TOKEN" ]; then
- curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \
- | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
+ OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}")
else
- curl -s "${URL}" \
- | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
+ OUTPUT=$(curl -s "${URL}")
fi
+ TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1)
+ echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+'
}
download_release_binary() {
VERSION=$(get_release "$NETBIRD_RELEASE")
+ echo "Using the following tag name for binary installation: ${TAG_NAME}"
BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download"
BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz"
diff --git a/route/route.go b/route/route.go
index 08a2d37dc..c724e7c7d 100644
--- a/route/route.go
+++ b/route/route.go
@@ -124,6 +124,7 @@ func (r *Route) EventMeta() map[string]any {
func (r *Route) Copy() *Route {
route := &Route{
ID: r.ID,
+ AccountID: r.AccountID,
Description: r.Description,
NetID: r.NetID,
Network: r.Network,
diff --git a/management/server/auth/jwt/extractor.go b/shared/auth/jwt/extractor.go
similarity index 92%
rename from management/server/auth/jwt/extractor.go
rename to shared/auth/jwt/extractor.go
index d270d0ff1..a41d5f07a 100644
--- a/management/server/auth/jwt/extractor.go
+++ b/shared/auth/jwt/extractor.go
@@ -8,7 +8,7 @@ import (
"github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/shared/auth"
)
const (
@@ -87,9 +87,10 @@ func (c ClaimsExtractor) audienceClaim(claimName string) string {
return url
}
-func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) {
+// ToUserAuth extracts user authentication information from a JWT token
+func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) {
claims := token.Claims.(jwt.MapClaims)
- userAuth := nbcontext.UserAuth{}
+ userAuth := auth.UserAuth{}
userID, ok := claims[c.userIDClaim].(string)
if !ok {
@@ -122,6 +123,7 @@ func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, erro
return userAuth, nil
}
+// ToGroups extracts group information from a JWT token
func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string {
claims := token.Claims.(jwt.MapClaims)
userJWTGroups := make([]string, 0)
diff --git a/management/server/auth/jwt/validator.go b/shared/auth/jwt/validator.go
similarity index 100%
rename from management/server/auth/jwt/validator.go
rename to shared/auth/jwt/validator.go
diff --git a/shared/auth/user.go b/shared/auth/user.go
new file mode 100644
index 000000000..c1bae808e
--- /dev/null
+++ b/shared/auth/user.go
@@ -0,0 +1,28 @@
+package auth
+
+import (
+ "time"
+)
+
+type UserAuth struct {
+ // The account id the user is accessing
+ AccountId string
+ // The account domain
+ Domain string
+ // The account domain category, TBC values
+ DomainCategory string
+ // Indicates whether this user was invited, TBC logic
+ Invited bool
+ // Indicates whether this is a child account
+ IsChild bool
+
+ // The user id
+ UserId string
+ // Last login time for this user
+ LastLogin time.Time
+ // The Groups the user belongs to on this account
+ Groups []string
+
+ // Indicates whether this user has authenticated with a Personal Access Token
+ IsPAT bool
+}
diff --git a/shared/context/keys.go b/shared/context/keys.go
index 5345ee214..c5b5da044 100644
--- a/shared/context/keys.go
+++ b/shared/context/keys.go
@@ -5,4 +5,4 @@ const (
AccountIDKey = "accountID"
UserIDKey = "userID"
PeerIDKey = "peerID"
-)
\ No newline at end of file
+)
diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go
index d4a9f1823..9fbe70948 100644
--- a/shared/management/client/client_test.go
+++ b/shared/management/client/client_test.go
@@ -19,6 +19,12 @@ import (
"github.com/netbirdio/management-integrations/integrations"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
+ "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
+ "github.com/netbirdio/netbird/management/internals/modules/peers"
+ "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
+ nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
+
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/server/config"
@@ -27,8 +33,6 @@ import (
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
- "github.com/netbirdio/netbird/management/server/peers"
- "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -68,7 +72,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
}
t.Cleanup(cleanUp)
- peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ctrl := gomock.NewController(t)
@@ -111,15 +114,22 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
Return(&types.ExtraSettings{}, nil).
AnyTimes()
- accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
+ ctx := context.Background()
+ updateManager := update_channel.NewPeersUpdateManager(metrics)
+ requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
+ networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManger), config)
+ accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
groupsManager := groups.NewManagerMock()
- secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
- mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{})
+ secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
+ if err != nil {
+ t.Fatal(err)
+ }
+ mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil {
t.Fatal(err)
}
diff --git a/shared/management/client/common/types.go b/shared/management/client/common/types.go
index 699617574..451578358 100644
--- a/shared/management/client/common/types.go
+++ b/shared/management/client/common/types.go
@@ -1,19 +1,20 @@
package common
-// LoginFlag introduces additional login flags to the PKCE authorization request
+// LoginFlag introduces additional login flags to the PKCE authorization request.
+//
+// # Config Values
+//
+// | Value | Flag | OAuth Parameters |
+// |-------|----------------------|-----------------------------------------|
+// | 0 | LoginFlagPromptLogin | prompt=login |
+// | 1 | LoginFlagMaxAge0 | max_age=0 |
type LoginFlag uint8
const (
- // LoginFlagPrompt adds prompt=login to the authorization request
- LoginFlagPrompt LoginFlag = iota
+ // LoginFlagPromptLogin adds prompt=login to the authorization request
+ LoginFlagPromptLogin LoginFlag = iota
// LoginFlagMaxAge0 adds max_age=0 to the authorization request
LoginFlagMaxAge0
+ // LoginFlagNone disables all login flags
+ LoginFlagNone
)
-
-func (l LoginFlag) IsPromptLogin() bool {
- return l == LoginFlagPrompt
-}
-
-func (l LoginFlag) IsMaxAge0Login() bool {
- return l == LoginFlagMaxAge0
-}
diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go
index 076f2532b..89860ac9b 100644
--- a/shared/management/client/grpc.go
+++ b/shared/management/client/grpc.go
@@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
- log.Printf("createConnection error: %v", err)
- return err
+ return fmt.Errorf("create connection: %w", err)
}
return nil
}
@@ -112,6 +111,8 @@ func (c *GrpcClient) ready() bool {
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
+ backOff := defaultBackoff(ctx)
+
operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
connState := c.conn.GetState()
@@ -129,10 +130,10 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return err
}
- return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
+ return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler, backOff)
}
- err := backoff.Retry(operation, defaultBackoff(ctx))
+ err := backoff.Retry(operation, backOff)
if err != nil {
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
}
@@ -141,7 +142,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
}
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
- msgHandler func(msg *proto.SyncResponse) error) error {
+ msgHandler func(msg *proto.SyncResponse) error, backOff backoff.BackOff) error {
ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
@@ -159,6 +160,9 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key,
// blocking until error
err = c.receiveEvents(stream, serverPubKey, msgHandler)
+ // we need this reset because after a successful connection and a consequent error, backoff lib doesn't
+ // reset times and next try will start with a long delay
+ backOff.Reset()
if err != nil {
c.notifyDisconnected(err)
s, _ := gstatus.FromError(err)
diff --git a/shared/management/client/rest/groups.go b/shared/management/client/rest/groups.go
index af068e077..7cd9535dd 100644
--- a/shared/management/client/rest/groups.go
+++ b/shared/management/client/rest/groups.go
@@ -4,10 +4,14 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"github.com/netbirdio/netbird/shared/management/http/api"
)
+// ErrGroupNotFound is returned when a group is not found
+var ErrGroupNotFound = errors.New("group not found")
+
// GroupsAPI APIs for Groups, do not use directly
type GroupsAPI struct {
c *Client
@@ -27,6 +31,27 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
return ret, err
}
+// GetByName get group by name
+// See more: https://docs.netbird.io/api/resources/groups#list-all-groups
+func (a *GroupsAPI) GetByName(ctx context.Context, groupName string) (*api.Group, error) {
+ params := map[string]string{"name": groupName}
+ resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil, params)
+ if err != nil {
+ return nil, err
+ }
+ if resp.Body != nil {
+ defer resp.Body.Close()
+ }
+ ret, err := parseResponse[[]api.Group](resp)
+ if err != nil {
+ return nil, err
+ }
+ if len(ret) == 0 {
+ return nil, ErrGroupNotFound
+ }
+ return &ret[0], nil
+}
+
// Get get group info
// See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group
func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) {
diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml
index 93578b1ae..c9edcdda6 100644
--- a/shared/management/http/api/openapi.yml
+++ b/shared/management/http/api/openapi.yml
@@ -145,6 +145,10 @@ components:
description: Enables or disables experimental lazy connection
type: boolean
example: true
+ auto_update_version:
+ description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1")
+ type: string
+ example: "0.51.2"
required:
- peer_login_expiration_enabled
- peer_login_expiration
@@ -463,6 +467,9 @@ components:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
example: true
+ disapproval_reason:
+ description: (Cloud only) Reason why the peer requires approval
+ type: string
country_code:
$ref: '#/components/schemas/CountryCode'
city_name:
@@ -481,6 +488,8 @@ components:
description: Indicates whether the peer is ephemeral or not
type: boolean
example: false
+ local_flags:
+ $ref: '#/components/schemas/PeerLocalFlags'
required:
- city_name
- connected
@@ -507,6 +516,49 @@ components:
- serial_number
- extra_dns_labels
- ephemeral
+ PeerLocalFlags:
+ type: object
+ properties:
+ rosenpass_enabled:
+ description: Indicates whether Rosenpass is enabled on this peer
+ type: boolean
+ example: true
+ rosenpass_permissive:
+ description: Indicates whether Rosenpass is in permissive mode or not
+ type: boolean
+ example: false
+ server_ssh_allowed:
+ description: Indicates whether SSH access this peer is allowed or not
+ type: boolean
+ example: true
+ disable_client_routes:
+ description: Indicates whether client routes are disabled on this peer or not
+ type: boolean
+ example: false
+ disable_server_routes:
+ description: Indicates whether server routes are disabled on this peer or not
+ type: boolean
+ example: false
+ disable_dns:
+ description: Indicates whether DNS management is disabled on this peer or not
+ type: boolean
+ example: false
+ disable_firewall:
+ description: Indicates whether firewall management is disabled on this peer or not
+ type: boolean
+ example: false
+ block_lan_access:
+ description: Indicates whether LAN access is blocked on this peer when used as a routing peer
+ type: boolean
+ example: false
+ block_inbound:
+ description: Indicates whether inbound traffic is blocked on this peer
+ type: boolean
+ example: false
+ lazy_connection_enabled:
+ description: Indicates whether lazy connection is enabled on this peer
+ type: boolean
+ example: false
PeerTemporaryAccessRequest:
type: object
properties:
@@ -929,7 +981,7 @@ components:
protocol:
description: Policy rule type of the traffic
type: string
- enum: ["all", "tcp", "udp", "icmp"]
+ enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"]
example: "tcp"
ports:
description: Policy rule affected ports
@@ -942,6 +994,14 @@ components:
type: array
items:
$ref: '#/components/schemas/RulePortRange'
+ authorized_groups:
+ description: Map of user group ids to a list of local users
+ type: object
+ additionalProperties:
+ type: array
+ items:
+ type: string
+ example: "group1"
required:
- name
- enabled
@@ -3359,6 +3419,14 @@ paths:
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
+ parameters:
+ - in: query
+ name: name
+ required: false
+ schema:
+ type: string
+ description: Filter groups by name (exact match)
+ example: "devs"
responses:
'200':
description: A JSON Array of Groups
@@ -3372,6 +3440,8 @@ paths:
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
+ '404':
+ "$ref": "#/components/responses/not_found"
'403':
"$ref": "#/components/responses/forbidden"
'500':
diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go
index 3dbb32ef6..f242f5a18 100644
--- a/shared/management/http/api/types.gen.go
+++ b/shared/management/http/api/types.gen.go
@@ -130,10 +130,11 @@ const (
// Defines values for PolicyRuleProtocol.
const (
- PolicyRuleProtocolAll PolicyRuleProtocol = "all"
- PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
- PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
- PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
+ PolicyRuleProtocolAll PolicyRuleProtocol = "all"
+ PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
+ PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh"
+ PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
+ PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
)
// Defines values for PolicyRuleMinimumAction.
@@ -144,10 +145,11 @@ const (
// Defines values for PolicyRuleMinimumProtocol.
const (
- PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
- PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
- PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
- PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
+ PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
+ PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
+ PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh"
+ PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
+ PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
)
// Defines values for PolicyRuleUpdateAction.
@@ -158,10 +160,11 @@ const (
// Defines values for PolicyRuleUpdateProtocol.
const (
- PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
- PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
- PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
- PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
+ PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
+ PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
+ PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh"
+ PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
+ PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
)
// Defines values for ResourceType.
@@ -291,6 +294,9 @@ type AccountRequest struct {
// AccountSettings defines model for AccountSettings.
type AccountSettings struct {
+ // AutoUpdateVersion Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1")
+ AutoUpdateVersion *string `json:"auto_update_version,omitempty"`
+
// DnsDomain Allows to define a custom dns domain for the account
DnsDomain *string `json:"dns_domain,omitempty"`
Extra *AccountExtraSettings `json:"extra,omitempty"`
@@ -1037,6 +1043,9 @@ type Peer struct {
// CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"`
+ // DisapprovalReason (Cloud only) Reason why the peer requires approval
+ DisapprovalReason *string `json:"disapproval_reason,omitempty"`
+
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
@@ -1071,7 +1080,8 @@ type Peer struct {
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
- LastSeen time.Time `json:"last_seen"`
+ LastSeen time.Time `json:"last_seen"`
+ LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
@@ -1124,6 +1134,9 @@ type PeerBatch struct {
// CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"`
+ // DisapprovalReason (Cloud only) Reason why the peer requires approval
+ DisapprovalReason *string `json:"disapproval_reason,omitempty"`
+
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
@@ -1158,7 +1171,8 @@ type PeerBatch struct {
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
- LastSeen time.Time `json:"last_seen"`
+ LastSeen time.Time `json:"last_seen"`
+ LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
@@ -1188,6 +1202,39 @@ type PeerBatch struct {
Version string `json:"version"`
}
+// PeerLocalFlags defines model for PeerLocalFlags.
+type PeerLocalFlags struct {
+ // BlockInbound Indicates whether inbound traffic is blocked on this peer
+ BlockInbound *bool `json:"block_inbound,omitempty"`
+
+ // BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer
+ BlockLanAccess *bool `json:"block_lan_access,omitempty"`
+
+ // DisableClientRoutes Indicates whether client routes are disabled on this peer or not
+ DisableClientRoutes *bool `json:"disable_client_routes,omitempty"`
+
+ // DisableDns Indicates whether DNS management is disabled on this peer or not
+ DisableDns *bool `json:"disable_dns,omitempty"`
+
+ // DisableFirewall Indicates whether firewall management is disabled on this peer or not
+ DisableFirewall *bool `json:"disable_firewall,omitempty"`
+
+ // DisableServerRoutes Indicates whether server routes are disabled on this peer or not
+ DisableServerRoutes *bool `json:"disable_server_routes,omitempty"`
+
+ // LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer
+ LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
+
+ // RosenpassEnabled Indicates whether Rosenpass is enabled on this peer
+ RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"`
+
+ // RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not
+ RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"`
+
+ // ServerSshAllowed Indicates whether SSH access this peer is allowed or not
+ ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"`
+}
+
// PeerMinimum defines model for PeerMinimum.
type PeerMinimum struct {
// Id Peer ID
@@ -1340,6 +1387,9 @@ type PolicyRule struct {
// Action Policy rule accept or drops packets
Action PolicyRuleAction `json:"action"`
+ // AuthorizedGroups Map of user group ids to a list of local users
+ AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
+
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"`
@@ -1384,6 +1434,9 @@ type PolicyRuleMinimum struct {
// Action Policy rule accept or drops packets
Action PolicyRuleMinimumAction `json:"action"`
+ // AuthorizedGroups Map of user group ids to a list of local users
+ AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
+
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"`
@@ -1417,6 +1470,9 @@ type PolicyRuleUpdate struct {
// Action Policy rule accept or drops packets
Action PolicyRuleUpdateAction `json:"action"`
+ // AuthorizedGroups Map of user group ids to a list of local users
+ AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
+
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"`
@@ -1902,6 +1958,12 @@ type GetApiEventsNetworkTrafficParamsConnectionType string
// GetApiEventsNetworkTrafficParamsDirection defines parameters for GetApiEventsNetworkTraffic.
type GetApiEventsNetworkTrafficParamsDirection string
+// GetApiGroupsParams defines parameters for GetApiGroups.
+type GetApiGroupsParams struct {
+ // Name Filter groups by name (exact match)
+ Name *string `form:"name,omitempty" json:"name,omitempty"`
+}
+
// GetApiPeersParams defines parameters for GetApiPeers.
type GetApiPeersParams struct {
// Name Filter peers by name
diff --git a/shared/management/http/util/util.go b/shared/management/http/util/util.go
index 3ae321023..0a29469da 100644
--- a/shared/management/http/util/util.go
+++ b/shared/management/http/util/util.go
@@ -106,6 +106,8 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
httpStatus = http.StatusUnauthorized
case status.BadRequest:
httpStatus = http.StatusBadRequest
+ case status.TooManyRequests:
+ httpStatus = http.StatusTooManyRequests
default:
}
msg = strings.ToLower(err.Error())
diff --git a/shared/management/operations/operation.go b/shared/management/operations/operation.go
index b9b500362..b1ba12815 100644
--- a/shared/management/operations/operation.go
+++ b/shared/management/operations/operation.go
@@ -1,4 +1,4 @@
package operations
// Operation represents a permission operation type
-type Operation string
\ No newline at end of file
+type Operation string
diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go
index 0de00ec0c..2047c51ea 100644
--- a/shared/management/proto/management.pb.go
+++ b/shared/management/proto/management.pb.go
@@ -1,19 +1,18 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
-// protoc v6.32.0
+// protoc v6.33.1
// source: management.proto
package proto
import (
- reflect "reflect"
- sync "sync"
-
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
durationpb "google.golang.org/protobuf/types/known/durationpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
+ reflect "reflect"
+ sync "sync"
)
const (
@@ -268,7 +267,7 @@ func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber {
// Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead.
func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{23, 0}
+ return file_management_proto_rawDescGZIP(), []int{27, 0}
}
type EncryptedMessage struct {
@@ -799,16 +798,21 @@ type Flags struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
- RosenpassEnabled bool `protobuf:"varint,1,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
- RosenpassPermissive bool `protobuf:"varint,2,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"`
- ServerSSHAllowed bool `protobuf:"varint,3,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"`
- DisableClientRoutes bool `protobuf:"varint,4,opt,name=disableClientRoutes,proto3" json:"disableClientRoutes,omitempty"`
- DisableServerRoutes bool `protobuf:"varint,5,opt,name=disableServerRoutes,proto3" json:"disableServerRoutes,omitempty"`
- DisableDNS bool `protobuf:"varint,6,opt,name=disableDNS,proto3" json:"disableDNS,omitempty"`
- DisableFirewall bool `protobuf:"varint,7,opt,name=disableFirewall,proto3" json:"disableFirewall,omitempty"`
- BlockLANAccess bool `protobuf:"varint,8,opt,name=blockLANAccess,proto3" json:"blockLANAccess,omitempty"`
- BlockInbound bool `protobuf:"varint,9,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"`
- LazyConnectionEnabled bool `protobuf:"varint,10,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"`
+ RosenpassEnabled bool `protobuf:"varint,1,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
+ RosenpassPermissive bool `protobuf:"varint,2,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"`
+ ServerSSHAllowed bool `protobuf:"varint,3,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"`
+ DisableClientRoutes bool `protobuf:"varint,4,opt,name=disableClientRoutes,proto3" json:"disableClientRoutes,omitempty"`
+ DisableServerRoutes bool `protobuf:"varint,5,opt,name=disableServerRoutes,proto3" json:"disableServerRoutes,omitempty"`
+ DisableDNS bool `protobuf:"varint,6,opt,name=disableDNS,proto3" json:"disableDNS,omitempty"`
+ DisableFirewall bool `protobuf:"varint,7,opt,name=disableFirewall,proto3" json:"disableFirewall,omitempty"`
+ BlockLANAccess bool `protobuf:"varint,8,opt,name=blockLANAccess,proto3" json:"blockLANAccess,omitempty"`
+ BlockInbound bool `protobuf:"varint,9,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"`
+ LazyConnectionEnabled bool `protobuf:"varint,10,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"`
+ EnableSSHRoot bool `protobuf:"varint,11,opt,name=enableSSHRoot,proto3" json:"enableSSHRoot,omitempty"`
+ EnableSSHSFTP bool `protobuf:"varint,12,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"`
+ EnableSSHLocalPortForwarding bool `protobuf:"varint,13,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"`
+ EnableSSHRemotePortForwarding bool `protobuf:"varint,14,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
+ DisableSSHAuth bool `protobuf:"varint,15,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
}
func (x *Flags) Reset() {
@@ -913,6 +917,41 @@ func (x *Flags) GetLazyConnectionEnabled() bool {
return false
}
+func (x *Flags) GetEnableSSHRoot() bool {
+ if x != nil {
+ return x.EnableSSHRoot
+ }
+ return false
+}
+
+func (x *Flags) GetEnableSSHSFTP() bool {
+ if x != nil {
+ return x.EnableSSHSFTP
+ }
+ return false
+}
+
+func (x *Flags) GetEnableSSHLocalPortForwarding() bool {
+ if x != nil {
+ return x.EnableSSHLocalPortForwarding
+ }
+ return false
+}
+
+func (x *Flags) GetEnableSSHRemotePortForwarding() bool {
+ if x != nil {
+ return x.EnableSSHRemotePortForwarding
+ }
+ return false
+}
+
+func (x *Flags) GetDisableSSHAuth() bool {
+ if x != nil {
+ return x.DisableSSHAuth
+ }
+ return false
+}
+
// PeerSystemMeta is machine meta data like OS and version.
type PeerSystemMeta struct {
state protoimpl.MessageState
@@ -1568,6 +1607,78 @@ func (x *FlowConfig) GetDnsCollection() bool {
return false
}
+// JWTConfig represents JWT authentication configuration
+type JWTConfig struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Issuer string `protobuf:"bytes,1,opt,name=issuer,proto3" json:"issuer,omitempty"`
+ Audience string `protobuf:"bytes,2,opt,name=audience,proto3" json:"audience,omitempty"`
+ KeysLocation string `protobuf:"bytes,3,opt,name=keysLocation,proto3" json:"keysLocation,omitempty"`
+ MaxTokenAge int64 `protobuf:"varint,4,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"`
+}
+
+func (x *JWTConfig) Reset() {
+ *x = JWTConfig{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_management_proto_msgTypes[17]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *JWTConfig) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*JWTConfig) ProtoMessage() {}
+
+func (x *JWTConfig) ProtoReflect() protoreflect.Message {
+ mi := &file_management_proto_msgTypes[17]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use JWTConfig.ProtoReflect.Descriptor instead.
+func (*JWTConfig) Descriptor() ([]byte, []int) {
+ return file_management_proto_rawDescGZIP(), []int{17}
+}
+
+func (x *JWTConfig) GetIssuer() string {
+ if x != nil {
+ return x.Issuer
+ }
+ return ""
+}
+
+func (x *JWTConfig) GetAudience() string {
+ if x != nil {
+ return x.Audience
+ }
+ return ""
+}
+
+func (x *JWTConfig) GetKeysLocation() string {
+ if x != nil {
+ return x.KeysLocation
+ }
+ return ""
+}
+
+func (x *JWTConfig) GetMaxTokenAge() int64 {
+ if x != nil {
+ return x.MaxTokenAge
+ }
+ return 0
+}
+
// ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers
type ProtectedHostConfig struct {
@@ -1583,7 +1694,7 @@ type ProtectedHostConfig struct {
func (x *ProtectedHostConfig) Reset() {
*x = ProtectedHostConfig{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[17]
+ mi := &file_management_proto_msgTypes[18]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1596,7 +1707,7 @@ func (x *ProtectedHostConfig) String() string {
func (*ProtectedHostConfig) ProtoMessage() {}
func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[17]
+ mi := &file_management_proto_msgTypes[18]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1609,7 +1720,7 @@ func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message {
// Deprecated: Use ProtectedHostConfig.ProtoReflect.Descriptor instead.
func (*ProtectedHostConfig) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{17}
+ return file_management_proto_rawDescGZIP(), []int{18}
}
func (x *ProtectedHostConfig) GetHostConfig() *HostConfig {
@@ -1651,12 +1762,14 @@ type PeerConfig struct {
RoutingPeerDnsResolutionEnabled bool `protobuf:"varint,5,opt,name=RoutingPeerDnsResolutionEnabled,proto3" json:"RoutingPeerDnsResolutionEnabled,omitempty"`
LazyConnectionEnabled bool `protobuf:"varint,6,opt,name=LazyConnectionEnabled,proto3" json:"LazyConnectionEnabled,omitempty"`
Mtu int32 `protobuf:"varint,7,opt,name=mtu,proto3" json:"mtu,omitempty"`
+ // Auto-update config
+ AutoUpdate *AutoUpdateSettings `protobuf:"bytes,8,opt,name=autoUpdate,proto3" json:"autoUpdate,omitempty"`
}
func (x *PeerConfig) Reset() {
*x = PeerConfig{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[18]
+ mi := &file_management_proto_msgTypes[19]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1669,7 +1782,7 @@ func (x *PeerConfig) String() string {
func (*PeerConfig) ProtoMessage() {}
func (x *PeerConfig) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[18]
+ mi := &file_management_proto_msgTypes[19]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1682,7 +1795,7 @@ func (x *PeerConfig) ProtoReflect() protoreflect.Message {
// Deprecated: Use PeerConfig.ProtoReflect.Descriptor instead.
func (*PeerConfig) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{18}
+ return file_management_proto_rawDescGZIP(), []int{19}
}
func (x *PeerConfig) GetAddress() string {
@@ -1734,6 +1847,70 @@ func (x *PeerConfig) GetMtu() int32 {
return 0
}
+func (x *PeerConfig) GetAutoUpdate() *AutoUpdateSettings {
+ if x != nil {
+ return x.AutoUpdate
+ }
+ return nil
+}
+
+type AutoUpdateSettings struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"`
+ // alwaysUpdate = true → Updates happen automatically in the background
+ // alwaysUpdate = false → Updates only happen when triggered by a peer connection
+ AlwaysUpdate bool `protobuf:"varint,2,opt,name=alwaysUpdate,proto3" json:"alwaysUpdate,omitempty"`
+}
+
+func (x *AutoUpdateSettings) Reset() {
+ *x = AutoUpdateSettings{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_management_proto_msgTypes[20]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *AutoUpdateSettings) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*AutoUpdateSettings) ProtoMessage() {}
+
+func (x *AutoUpdateSettings) ProtoReflect() protoreflect.Message {
+ mi := &file_management_proto_msgTypes[20]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use AutoUpdateSettings.ProtoReflect.Descriptor instead.
+func (*AutoUpdateSettings) Descriptor() ([]byte, []int) {
+ return file_management_proto_rawDescGZIP(), []int{20}
+}
+
+func (x *AutoUpdateSettings) GetVersion() string {
+ if x != nil {
+ return x.Version
+ }
+ return ""
+}
+
+func (x *AutoUpdateSettings) GetAlwaysUpdate() bool {
+ if x != nil {
+ return x.AlwaysUpdate
+ }
+ return false
+}
+
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
type NetworkMap struct {
state protoimpl.MessageState
@@ -1765,12 +1942,14 @@ type NetworkMap struct {
// RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality.
RoutesFirewallRulesIsEmpty bool `protobuf:"varint,11,opt,name=routesFirewallRulesIsEmpty,proto3" json:"routesFirewallRulesIsEmpty,omitempty"`
ForwardingRules []*ForwardingRule `protobuf:"bytes,12,rep,name=forwardingRules,proto3" json:"forwardingRules,omitempty"`
+ // SSHAuth represents SSH authorization configuration
+ SshAuth *SSHAuth `protobuf:"bytes,13,opt,name=sshAuth,proto3" json:"sshAuth,omitempty"`
}
func (x *NetworkMap) Reset() {
*x = NetworkMap{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[19]
+ mi := &file_management_proto_msgTypes[21]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1783,7 +1962,7 @@ func (x *NetworkMap) String() string {
func (*NetworkMap) ProtoMessage() {}
func (x *NetworkMap) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[19]
+ mi := &file_management_proto_msgTypes[21]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1796,7 +1975,7 @@ func (x *NetworkMap) ProtoReflect() protoreflect.Message {
// Deprecated: Use NetworkMap.ProtoReflect.Descriptor instead.
func (*NetworkMap) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{19}
+ return file_management_proto_rawDescGZIP(), []int{21}
}
func (x *NetworkMap) GetSerial() uint64 {
@@ -1883,6 +2062,126 @@ func (x *NetworkMap) GetForwardingRules() []*ForwardingRule {
return nil
}
+func (x *NetworkMap) GetSshAuth() *SSHAuth {
+ if x != nil {
+ return x.SshAuth
+ }
+ return nil
+}
+
+type SSHAuth struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ // UserIDClaim is the JWT claim to be used to get the users ID
+ UserIDClaim string `protobuf:"bytes,1,opt,name=UserIDClaim,proto3" json:"UserIDClaim,omitempty"`
+ // AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH
+ AuthorizedUsers [][]byte `protobuf:"bytes,2,rep,name=AuthorizedUsers,proto3" json:"AuthorizedUsers,omitempty"`
+ // MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list
+ MachineUsers map[string]*MachineUserIndexes `protobuf:"bytes,3,rep,name=machine_users,json=machineUsers,proto3" json:"machine_users,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
+}
+
+func (x *SSHAuth) Reset() {
+ *x = SSHAuth{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_management_proto_msgTypes[22]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *SSHAuth) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*SSHAuth) ProtoMessage() {}
+
+func (x *SSHAuth) ProtoReflect() protoreflect.Message {
+ mi := &file_management_proto_msgTypes[22]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use SSHAuth.ProtoReflect.Descriptor instead.
+func (*SSHAuth) Descriptor() ([]byte, []int) {
+ return file_management_proto_rawDescGZIP(), []int{22}
+}
+
+func (x *SSHAuth) GetUserIDClaim() string {
+ if x != nil {
+ return x.UserIDClaim
+ }
+ return ""
+}
+
+func (x *SSHAuth) GetAuthorizedUsers() [][]byte {
+ if x != nil {
+ return x.AuthorizedUsers
+ }
+ return nil
+}
+
+func (x *SSHAuth) GetMachineUsers() map[string]*MachineUserIndexes {
+ if x != nil {
+ return x.MachineUsers
+ }
+ return nil
+}
+
+type MachineUserIndexes struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Indexes []uint32 `protobuf:"varint,1,rep,packed,name=indexes,proto3" json:"indexes,omitempty"`
+}
+
+func (x *MachineUserIndexes) Reset() {
+ *x = MachineUserIndexes{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_management_proto_msgTypes[23]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *MachineUserIndexes) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*MachineUserIndexes) ProtoMessage() {}
+
+func (x *MachineUserIndexes) ProtoReflect() protoreflect.Message {
+ mi := &file_management_proto_msgTypes[23]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use MachineUserIndexes.ProtoReflect.Descriptor instead.
+func (*MachineUserIndexes) Descriptor() ([]byte, []int) {
+ return file_management_proto_rawDescGZIP(), []int{23}
+}
+
+func (x *MachineUserIndexes) GetIndexes() []uint32 {
+ if x != nil {
+ return x.Indexes
+ }
+ return nil
+}
+
// RemotePeerConfig represents a configuration of a remote peer.
// The properties are used to configure WireGuard Peers sections
type RemotePeerConfig struct {
@@ -1904,7 +2203,7 @@ type RemotePeerConfig struct {
func (x *RemotePeerConfig) Reset() {
*x = RemotePeerConfig{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[20]
+ mi := &file_management_proto_msgTypes[24]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1917,7 +2216,7 @@ func (x *RemotePeerConfig) String() string {
func (*RemotePeerConfig) ProtoMessage() {}
func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[20]
+ mi := &file_management_proto_msgTypes[24]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1930,7 +2229,7 @@ func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message {
// Deprecated: Use RemotePeerConfig.ProtoReflect.Descriptor instead.
func (*RemotePeerConfig) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{20}
+ return file_management_proto_rawDescGZIP(), []int{24}
}
func (x *RemotePeerConfig) GetWgPubKey() string {
@@ -1978,13 +2277,14 @@ type SSHConfig struct {
SshEnabled bool `protobuf:"varint,1,opt,name=sshEnabled,proto3" json:"sshEnabled,omitempty"`
// sshPubKey is a SSH public key of a peer to be added to authorized_hosts.
// This property should be ignore if SSHConfig comes from PeerConfig.
- SshPubKey []byte `protobuf:"bytes,2,opt,name=sshPubKey,proto3" json:"sshPubKey,omitempty"`
+ SshPubKey []byte `protobuf:"bytes,2,opt,name=sshPubKey,proto3" json:"sshPubKey,omitempty"`
+ JwtConfig *JWTConfig `protobuf:"bytes,3,opt,name=jwtConfig,proto3" json:"jwtConfig,omitempty"`
}
func (x *SSHConfig) Reset() {
*x = SSHConfig{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[21]
+ mi := &file_management_proto_msgTypes[25]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1997,7 +2297,7 @@ func (x *SSHConfig) String() string {
func (*SSHConfig) ProtoMessage() {}
func (x *SSHConfig) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[21]
+ mi := &file_management_proto_msgTypes[25]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2010,7 +2310,7 @@ func (x *SSHConfig) ProtoReflect() protoreflect.Message {
// Deprecated: Use SSHConfig.ProtoReflect.Descriptor instead.
func (*SSHConfig) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{21}
+ return file_management_proto_rawDescGZIP(), []int{25}
}
func (x *SSHConfig) GetSshEnabled() bool {
@@ -2027,6 +2327,13 @@ func (x *SSHConfig) GetSshPubKey() []byte {
return nil
}
+func (x *SSHConfig) GetJwtConfig() *JWTConfig {
+ if x != nil {
+ return x.JwtConfig
+ }
+ return nil
+}
+
// DeviceAuthorizationFlowRequest empty struct for future expansion
type DeviceAuthorizationFlowRequest struct {
state protoimpl.MessageState
@@ -2037,7 +2344,7 @@ type DeviceAuthorizationFlowRequest struct {
func (x *DeviceAuthorizationFlowRequest) Reset() {
*x = DeviceAuthorizationFlowRequest{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[22]
+ mi := &file_management_proto_msgTypes[26]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2050,7 +2357,7 @@ func (x *DeviceAuthorizationFlowRequest) String() string {
func (*DeviceAuthorizationFlowRequest) ProtoMessage() {}
func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[22]
+ mi := &file_management_proto_msgTypes[26]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2063,7 +2370,7 @@ func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use DeviceAuthorizationFlowRequest.ProtoReflect.Descriptor instead.
func (*DeviceAuthorizationFlowRequest) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{22}
+ return file_management_proto_rawDescGZIP(), []int{26}
}
// DeviceAuthorizationFlow represents Device Authorization Flow information
@@ -2082,7 +2389,7 @@ type DeviceAuthorizationFlow struct {
func (x *DeviceAuthorizationFlow) Reset() {
*x = DeviceAuthorizationFlow{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[23]
+ mi := &file_management_proto_msgTypes[27]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2095,7 +2402,7 @@ func (x *DeviceAuthorizationFlow) String() string {
func (*DeviceAuthorizationFlow) ProtoMessage() {}
func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[23]
+ mi := &file_management_proto_msgTypes[27]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2108,7 +2415,7 @@ func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message {
// Deprecated: Use DeviceAuthorizationFlow.ProtoReflect.Descriptor instead.
func (*DeviceAuthorizationFlow) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{23}
+ return file_management_proto_rawDescGZIP(), []int{27}
}
func (x *DeviceAuthorizationFlow) GetProvider() DeviceAuthorizationFlowProvider {
@@ -2135,7 +2442,7 @@ type PKCEAuthorizationFlowRequest struct {
func (x *PKCEAuthorizationFlowRequest) Reset() {
*x = PKCEAuthorizationFlowRequest{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[24]
+ mi := &file_management_proto_msgTypes[28]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2148,7 +2455,7 @@ func (x *PKCEAuthorizationFlowRequest) String() string {
func (*PKCEAuthorizationFlowRequest) ProtoMessage() {}
func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[24]
+ mi := &file_management_proto_msgTypes[28]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2161,7 +2468,7 @@ func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use PKCEAuthorizationFlowRequest.ProtoReflect.Descriptor instead.
func (*PKCEAuthorizationFlowRequest) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{24}
+ return file_management_proto_rawDescGZIP(), []int{28}
}
// PKCEAuthorizationFlow represents Authorization Code Flow information
@@ -2178,7 +2485,7 @@ type PKCEAuthorizationFlow struct {
func (x *PKCEAuthorizationFlow) Reset() {
*x = PKCEAuthorizationFlow{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[25]
+ mi := &file_management_proto_msgTypes[29]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2191,7 +2498,7 @@ func (x *PKCEAuthorizationFlow) String() string {
func (*PKCEAuthorizationFlow) ProtoMessage() {}
func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[25]
+ mi := &file_management_proto_msgTypes[29]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2204,7 +2511,7 @@ func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message {
// Deprecated: Use PKCEAuthorizationFlow.ProtoReflect.Descriptor instead.
func (*PKCEAuthorizationFlow) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{25}
+ return file_management_proto_rawDescGZIP(), []int{29}
}
func (x *PKCEAuthorizationFlow) GetProviderConfig() *ProviderConfig {
@@ -2250,7 +2557,7 @@ type ProviderConfig struct {
func (x *ProviderConfig) Reset() {
*x = ProviderConfig{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[26]
+ mi := &file_management_proto_msgTypes[30]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2263,7 +2570,7 @@ func (x *ProviderConfig) String() string {
func (*ProviderConfig) ProtoMessage() {}
func (x *ProviderConfig) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[26]
+ mi := &file_management_proto_msgTypes[30]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2276,7 +2583,7 @@ func (x *ProviderConfig) ProtoReflect() protoreflect.Message {
// Deprecated: Use ProviderConfig.ProtoReflect.Descriptor instead.
func (*ProviderConfig) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{26}
+ return file_management_proto_rawDescGZIP(), []int{30}
}
func (x *ProviderConfig) GetClientID() string {
@@ -2384,7 +2691,7 @@ type Route struct {
func (x *Route) Reset() {
*x = Route{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[27]
+ mi := &file_management_proto_msgTypes[31]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2397,7 +2704,7 @@ func (x *Route) String() string {
func (*Route) ProtoMessage() {}
func (x *Route) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[27]
+ mi := &file_management_proto_msgTypes[31]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2410,7 +2717,7 @@ func (x *Route) ProtoReflect() protoreflect.Message {
// Deprecated: Use Route.ProtoReflect.Descriptor instead.
func (*Route) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{27}
+ return file_management_proto_rawDescGZIP(), []int{31}
}
func (x *Route) GetID() string {
@@ -2492,13 +2799,14 @@ type DNSConfig struct {
ServiceEnable bool `protobuf:"varint,1,opt,name=ServiceEnable,proto3" json:"ServiceEnable,omitempty"`
NameServerGroups []*NameServerGroup `protobuf:"bytes,2,rep,name=NameServerGroups,proto3" json:"NameServerGroups,omitempty"`
CustomZones []*CustomZone `protobuf:"bytes,3,rep,name=CustomZones,proto3" json:"CustomZones,omitempty"`
- ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"`
+ // Deprecated: Do not use.
+ ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"`
}
func (x *DNSConfig) Reset() {
*x = DNSConfig{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[28]
+ mi := &file_management_proto_msgTypes[32]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2511,7 +2819,7 @@ func (x *DNSConfig) String() string {
func (*DNSConfig) ProtoMessage() {}
func (x *DNSConfig) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[28]
+ mi := &file_management_proto_msgTypes[32]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2524,7 +2832,7 @@ func (x *DNSConfig) ProtoReflect() protoreflect.Message {
// Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead.
func (*DNSConfig) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{28}
+ return file_management_proto_rawDescGZIP(), []int{32}
}
func (x *DNSConfig) GetServiceEnable() bool {
@@ -2548,6 +2856,7 @@ func (x *DNSConfig) GetCustomZones() []*CustomZone {
return nil
}
+// Deprecated: Do not use.
func (x *DNSConfig) GetForwarderPort() int64 {
if x != nil {
return x.ForwarderPort
@@ -2561,14 +2870,16 @@ type CustomZone struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
- Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"`
- Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"`
+ Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"`
+ Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"`
+ SearchDomainDisabled bool `protobuf:"varint,3,opt,name=SearchDomainDisabled,proto3" json:"SearchDomainDisabled,omitempty"`
+ SkipPTRProcess bool `protobuf:"varint,4,opt,name=SkipPTRProcess,proto3" json:"SkipPTRProcess,omitempty"`
}
func (x *CustomZone) Reset() {
*x = CustomZone{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[29]
+ mi := &file_management_proto_msgTypes[33]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2581,7 +2892,7 @@ func (x *CustomZone) String() string {
func (*CustomZone) ProtoMessage() {}
func (x *CustomZone) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[29]
+ mi := &file_management_proto_msgTypes[33]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2594,7 +2905,7 @@ func (x *CustomZone) ProtoReflect() protoreflect.Message {
// Deprecated: Use CustomZone.ProtoReflect.Descriptor instead.
func (*CustomZone) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{29}
+ return file_management_proto_rawDescGZIP(), []int{33}
}
func (x *CustomZone) GetDomain() string {
@@ -2611,6 +2922,20 @@ func (x *CustomZone) GetRecords() []*SimpleRecord {
return nil
}
+func (x *CustomZone) GetSearchDomainDisabled() bool {
+ if x != nil {
+ return x.SearchDomainDisabled
+ }
+ return false
+}
+
+func (x *CustomZone) GetSkipPTRProcess() bool {
+ if x != nil {
+ return x.SkipPTRProcess
+ }
+ return false
+}
+
// SimpleRecord represents a dns.SimpleRecord
type SimpleRecord struct {
state protoimpl.MessageState
@@ -2627,7 +2952,7 @@ type SimpleRecord struct {
func (x *SimpleRecord) Reset() {
*x = SimpleRecord{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[30]
+ mi := &file_management_proto_msgTypes[34]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2640,7 +2965,7 @@ func (x *SimpleRecord) String() string {
func (*SimpleRecord) ProtoMessage() {}
func (x *SimpleRecord) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[30]
+ mi := &file_management_proto_msgTypes[34]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2653,7 +2978,7 @@ func (x *SimpleRecord) ProtoReflect() protoreflect.Message {
// Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead.
func (*SimpleRecord) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{30}
+ return file_management_proto_rawDescGZIP(), []int{34}
}
func (x *SimpleRecord) GetName() string {
@@ -2706,7 +3031,7 @@ type NameServerGroup struct {
func (x *NameServerGroup) Reset() {
*x = NameServerGroup{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[31]
+ mi := &file_management_proto_msgTypes[35]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2719,7 +3044,7 @@ func (x *NameServerGroup) String() string {
func (*NameServerGroup) ProtoMessage() {}
func (x *NameServerGroup) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[31]
+ mi := &file_management_proto_msgTypes[35]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2732,7 +3057,7 @@ func (x *NameServerGroup) ProtoReflect() protoreflect.Message {
// Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead.
func (*NameServerGroup) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{31}
+ return file_management_proto_rawDescGZIP(), []int{35}
}
func (x *NameServerGroup) GetNameServers() []*NameServer {
@@ -2777,7 +3102,7 @@ type NameServer struct {
func (x *NameServer) Reset() {
*x = NameServer{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[32]
+ mi := &file_management_proto_msgTypes[36]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2790,7 +3115,7 @@ func (x *NameServer) String() string {
func (*NameServer) ProtoMessage() {}
func (x *NameServer) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[32]
+ mi := &file_management_proto_msgTypes[36]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2803,7 +3128,7 @@ func (x *NameServer) ProtoReflect() protoreflect.Message {
// Deprecated: Use NameServer.ProtoReflect.Descriptor instead.
func (*NameServer) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{32}
+ return file_management_proto_rawDescGZIP(), []int{36}
}
func (x *NameServer) GetIP() string {
@@ -2846,7 +3171,7 @@ type FirewallRule struct {
func (x *FirewallRule) Reset() {
*x = FirewallRule{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[33]
+ mi := &file_management_proto_msgTypes[37]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2859,7 +3184,7 @@ func (x *FirewallRule) String() string {
func (*FirewallRule) ProtoMessage() {}
func (x *FirewallRule) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[33]
+ mi := &file_management_proto_msgTypes[37]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2872,7 +3197,7 @@ func (x *FirewallRule) ProtoReflect() protoreflect.Message {
// Deprecated: Use FirewallRule.ProtoReflect.Descriptor instead.
func (*FirewallRule) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{33}
+ return file_management_proto_rawDescGZIP(), []int{37}
}
func (x *FirewallRule) GetPeerIP() string {
@@ -2936,7 +3261,7 @@ type NetworkAddress struct {
func (x *NetworkAddress) Reset() {
*x = NetworkAddress{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[34]
+ mi := &file_management_proto_msgTypes[38]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -2949,7 +3274,7 @@ func (x *NetworkAddress) String() string {
func (*NetworkAddress) ProtoMessage() {}
func (x *NetworkAddress) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[34]
+ mi := &file_management_proto_msgTypes[38]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -2962,7 +3287,7 @@ func (x *NetworkAddress) ProtoReflect() protoreflect.Message {
// Deprecated: Use NetworkAddress.ProtoReflect.Descriptor instead.
func (*NetworkAddress) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{34}
+ return file_management_proto_rawDescGZIP(), []int{38}
}
func (x *NetworkAddress) GetNetIP() string {
@@ -2990,7 +3315,7 @@ type Checks struct {
func (x *Checks) Reset() {
*x = Checks{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[35]
+ mi := &file_management_proto_msgTypes[39]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3003,7 +3328,7 @@ func (x *Checks) String() string {
func (*Checks) ProtoMessage() {}
func (x *Checks) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[35]
+ mi := &file_management_proto_msgTypes[39]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3016,7 +3341,7 @@ func (x *Checks) ProtoReflect() protoreflect.Message {
// Deprecated: Use Checks.ProtoReflect.Descriptor instead.
func (*Checks) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{35}
+ return file_management_proto_rawDescGZIP(), []int{39}
}
func (x *Checks) GetFiles() []string {
@@ -3041,7 +3366,7 @@ type PortInfo struct {
func (x *PortInfo) Reset() {
*x = PortInfo{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[36]
+ mi := &file_management_proto_msgTypes[40]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3054,7 +3379,7 @@ func (x *PortInfo) String() string {
func (*PortInfo) ProtoMessage() {}
func (x *PortInfo) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[36]
+ mi := &file_management_proto_msgTypes[40]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3067,7 +3392,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message {
// Deprecated: Use PortInfo.ProtoReflect.Descriptor instead.
func (*PortInfo) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{36}
+ return file_management_proto_rawDescGZIP(), []int{40}
}
func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection {
@@ -3138,7 +3463,7 @@ type RouteFirewallRule struct {
func (x *RouteFirewallRule) Reset() {
*x = RouteFirewallRule{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[37]
+ mi := &file_management_proto_msgTypes[41]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3151,7 +3476,7 @@ func (x *RouteFirewallRule) String() string {
func (*RouteFirewallRule) ProtoMessage() {}
func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[37]
+ mi := &file_management_proto_msgTypes[41]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3164,7 +3489,7 @@ func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message {
// Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead.
func (*RouteFirewallRule) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{37}
+ return file_management_proto_rawDescGZIP(), []int{41}
}
func (x *RouteFirewallRule) GetSourceRanges() []string {
@@ -3255,7 +3580,7 @@ type ForwardingRule struct {
func (x *ForwardingRule) Reset() {
*x = ForwardingRule{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[38]
+ mi := &file_management_proto_msgTypes[42]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3268,7 +3593,7 @@ func (x *ForwardingRule) String() string {
func (*ForwardingRule) ProtoMessage() {}
func (x *ForwardingRule) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[38]
+ mi := &file_management_proto_msgTypes[42]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3281,7 +3606,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message {
// Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead.
func (*ForwardingRule) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{38}
+ return file_management_proto_rawDescGZIP(), []int{42}
}
func (x *ForwardingRule) GetProtocol() RuleProtocol {
@@ -3324,7 +3649,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
if protoimpl.UnsafeEnabled {
- mi := &file_management_proto_msgTypes[39]
+ mi := &file_management_proto_msgTypes[44]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3337,7 +3662,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
- mi := &file_management_proto_msgTypes[39]
+ mi := &file_management_proto_msgTypes[44]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3350,7 +3675,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
// Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead.
func (*PortInfo_Range) Descriptor() ([]byte, []int) {
- return file_management_proto_rawDescGZIP(), []int{36, 0}
+ return file_management_proto_rawDescGZIP(), []int{40, 0}
}
func (x *PortInfo_Range) GetStart() uint32 {
@@ -3438,7 +3763,7 @@ var file_management_proto_rawDesc = []byte{
0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x78, 0x69, 0x73, 0x74, 0x12,
0x2a, 0x0a, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e,
0x69, 0x6e, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65,
- 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xc1, 0x03, 0x0a, 0x05,
+ 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xbf, 0x05, 0x0a, 0x05,
0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61,
0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52,
0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65,
@@ -3466,435 +3791,500 @@ var file_management_proto_rawDesc = []byte{
0x63, 0x6b, 0x49, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x61, 0x7a,
0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c,
0x65, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x6c, 0x61, 0x7a, 0x79, 0x43, 0x6f,
- 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22,
- 0xf2, 0x04, 0x0a, 0x0e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65,
- 0x74, 0x61, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12,
- 0x0a, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f,
- 0x4f, 0x53, 0x12, 0x16, 0x0a, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f,
- 0x72, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a,
- 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09,
- 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53,
- 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65,
- 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69,
- 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18,
- 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
- 0x12, 0x24, 0x0a, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f,
- 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56,
- 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73,
- 0x69, 0x6f, 0x6e, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72,
- 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41,
- 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a,
- 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77,
- 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77,
- 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f,
- 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18,
- 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c,
- 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f,
- 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e,
- 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28,
- 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65,
- 0x72, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75,
- 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69,
- 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e,
- 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72,
- 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d,
- 0x65, 0x6e, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03,
- 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
- 0x46, 0x69, 0x6c, 0x65, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66,
- 0x6c, 0x61, 0x67, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e,
- 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66,
- 0x6c, 0x61, 0x67, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65,
- 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72,
- 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e,
- 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69,
- 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72,
- 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43,
- 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61,
- 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e,
- 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
- 0x2a, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32,
- 0x12, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65,
- 0x63, 0x6b, 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53,
- 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
- 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b,
- 0x65, 0x79, 0x12, 0x38, 0x0a, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18,
- 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70,
- 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d,
- 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07,
- 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76,
- 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22,
- 0xff, 0x01, 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69,
- 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b,
- 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f,
- 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12,
- 0x35, 0x0a, 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f,
- 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74,
- 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52,
- 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c,
- 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
- 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06,
- 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x12, 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18,
- 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05,
- 0x72, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20,
- 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
- 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f,
- 0x77, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
- 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75,
- 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02,
- 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
- 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f,
- 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22,
- 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55,
- 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a,
- 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53,
- 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b,
- 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75,
- 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12,
- 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18,
- 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c,
- 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e,
- 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b,
- 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a,
- 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72,
- 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c,
- 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
- 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75,
- 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53,
- 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65,
- 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f,
- 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72,
- 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12,
- 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08,
- 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75,
- 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75,
- 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64,
- 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28,
- 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65,
- 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c,
- 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e,
- 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x7d, 0x0a, 0x13, 0x50,
- 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66,
- 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
- 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
- 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a,
- 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73,
- 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a,
- 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
- 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50,
- 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64,
- 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72,
- 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
- 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66,
- 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
- 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52,
- 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71,
- 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48,
- 0x0a, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73,
- 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65,
- 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67,
- 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f,
- 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79,
- 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65,
- 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e,
- 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10,
- 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75,
- 0x22, 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12,
- 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52,
- 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43,
- 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61,
- 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e,
- 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
- 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03,
- 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
- 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66,
- 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12,
- 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73,
- 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d,
- 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12,
- 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32,
- 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75,
- 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e,
- 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e,
- 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f,
- 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
- 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18,
- 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e,
- 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72,
- 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c,
- 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
- 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75,
- 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65,
- 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c,
- 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52,
- 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73,
- 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46,
- 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03,
- 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
- 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c,
- 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c,
- 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73,
- 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45,
- 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74,
- 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49,
- 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72,
- 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32,
- 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72,
- 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72,
- 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a,
- 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69,
- 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20,
- 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a,
- 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28,
- 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a,
- 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b,
- 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53,
- 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66,
- 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09,
- 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56,
- 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67,
- 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53,
- 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e,
- 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68,
- 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75,
- 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50,
- 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41,
- 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77,
- 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69,
+ 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12,
+ 0x24, 0x0a, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x52, 0x6f, 0x6f, 0x74,
+ 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53,
+ 0x48, 0x52, 0x6f, 0x6f, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53,
+ 0x53, 0x48, 0x53, 0x46, 0x54, 0x50, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x65, 0x6e,
+ 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x53, 0x46, 0x54, 0x50, 0x12, 0x42, 0x0a, 0x1c, 0x65,
+ 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x6f, 0x72,
+ 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x18, 0x0d, 0x20, 0x01, 0x28,
+ 0x08, 0x52, 0x1c, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x4c, 0x6f, 0x63, 0x61,
+ 0x6c, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x12,
+ 0x44, 0x0a, 0x1d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x52, 0x65, 0x6d, 0x6f,
+ 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67,
+ 0x18, 0x0e, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53,
+ 0x48, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61,
+ 0x72, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x26, 0x0a, 0x0e, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65,
+ 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x64,
+ 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x22, 0xf2, 0x04,
+ 0x0a, 0x0e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61,
+ 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01,
+ 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04,
+ 0x67, 0x6f, 0x4f, 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, 0x4f, 0x53,
+ 0x12, 0x16, 0x0a, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x72, 0x65,
+ 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08,
+ 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08,
+ 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, 0x18, 0x06,
+ 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x62,
+ 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
+ 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20,
+ 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24,
+ 0x0a, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18,
+ 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72,
+ 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f,
+ 0x6e, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69,
+ 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64,
+ 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d,
+ 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
+ 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72,
+ 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79,
+ 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x0c, 0x20,
+ 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75,
+ 0x6d, 0x62, 0x65, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75,
+ 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x79,
+ 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, 0x0a, 0x0f,
+ 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x18,
+ 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61,
+ 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f,
+ 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61,
+ 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e,
+ 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e,
+ 0x74, 0x12, 0x26, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x0b,
+ 0x32, 0x10, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69,
+ 0x6c, 0x65, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, 0x6c, 0x61,
+ 0x67, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
+ 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, 0x6c, 0x61,
+ 0x67, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70,
+ 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43,
+ 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61,
+ 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64,
+ 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43,
+ 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e,
+ 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
+ 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69,
+ 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2a, 0x0a,
+ 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e,
+ 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b,
+ 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, 0x65, 0x72,
+ 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10,
+ 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79,
+ 0x12, 0x38, 0x0a, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, 0x02, 0x20,
+ 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f,
+ 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52,
+ 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65,
+ 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72,
+ 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xff, 0x01,
+ 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
+ 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16,
+ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74,
+ 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, 0x35, 0x0a,
+ 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x6d,
+ 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63,
+ 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x74,
+ 0x75, 0x72, 0x6e, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x18, 0x03,
+ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
+ 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, 0x73, 0x69,
+ 0x67, 0x6e, 0x61, 0x6c, 0x12, 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x04, 0x20,
+ 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
+ 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65,
+ 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28,
+ 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46,
+ 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x22,
+ 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10,
+ 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69,
+ 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01,
+ 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
+ 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f,
+ 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a,
+ 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50,
+ 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48,
+ 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03,
+ 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65,
+ 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c,
+ 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a,
+ 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20,
+ 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61,
+ 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74,
+ 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e,
+ 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c,
+ 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18,
+ 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f,
+ 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26,
+ 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65,
+ 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67,
+ 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76,
+ 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c,
+ 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74,
+ 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a,
+ 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07,
+ 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74,
+ 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74,
+ 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43,
+ 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52,
+ 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74,
+ 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63,
+ 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43,
+ 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x85, 0x01, 0x0a, 0x09, 0x4a, 0x57,
+ 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65,
+ 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12,
+ 0x1a, 0x0a, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28,
+ 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b,
+ 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28,
+ 0x09, 0x52, 0x0c, 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12,
+ 0x20, 0x0a, 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04,
+ 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67,
+ 0x65, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f,
+ 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74,
+ 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d,
+ 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f,
+ 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
+ 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04,
+ 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64,
+ 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64,
+ 0x22, 0xd3, 0x02, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
+ 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73,
+ 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73,
+ 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15,
+ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43,
+ 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
+ 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04,
+ 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50,
+ 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e,
+ 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52,
+ 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73,
+ 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34,
+ 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e,
+ 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c,
+ 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61,
+ 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28,
+ 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x12, 0x3e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70,
+ 0x64, 0x61, 0x74, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e,
+ 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61,
+ 0x74, 0x65, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x6f,
+ 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0x52, 0x0a, 0x12, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70,
+ 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x18, 0x0a, 0x07,
+ 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76,
+ 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73,
+ 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x61, 0x6c,
+ 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0xe8, 0x05, 0x0a, 0x0a, 0x4e,
+ 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72,
+ 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61,
+ 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18,
+ 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
+ 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70,
+ 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d,
+ 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c,
+ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f,
+ 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65,
+ 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d,
+ 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18,
+ 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65,
+ 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75,
+ 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
+ 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f,
+ 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69,
+ 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
+ 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09,
+ 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66,
+ 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32,
+ 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d,
+ 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f,
+ 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46,
+ 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03,
+ 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
+ 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69,
+ 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66,
+ 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d,
+ 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77,
+ 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12,
+ 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c,
+ 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d,
+ 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46,
+ 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75,
+ 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73,
+ 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61,
+ 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b,
+ 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65,
+ 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79,
+ 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75,
+ 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
+ 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e,
+ 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e,
+ 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x2d, 0x0a, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74,
+ 0x68, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
+ 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x52, 0x07, 0x73, 0x73,
+ 0x68, 0x41, 0x75, 0x74, 0x68, 0x22, 0x82, 0x02, 0x0a, 0x07, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74,
+ 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 0x6d,
+ 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c,
+ 0x61, 0x69, 0x6d, 0x12, 0x28, 0x0a, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65,
+ 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0f, 0x41, 0x75,
+ 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x12, 0x4a, 0x0a,
+ 0x0d, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x5f, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x03,
+ 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
+ 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e,
+ 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x6d, 0x61, 0x63,
+ 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x1a, 0x5f, 0x0a, 0x11, 0x4d, 0x61, 0x63,
+ 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10,
+ 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79,
+ 0x12, 0x34, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32,
+ 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x61, 0x63,
+ 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x52,
+ 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x2e, 0x0a, 0x12, 0x4d, 0x61,
+ 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73,
+ 0x12, 0x18, 0x0a, 0x07, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28,
+ 0x0d, 0x52, 0x07, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52,
+ 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
+ 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28,
+ 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61,
+ 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52,
+ 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73,
+ 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15,
+ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43,
+ 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
+ 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04,
+ 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72,
+ 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e,
+ 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43,
+ 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62,
+ 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e,
+ 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b,
+ 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62,
+ 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
+ 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
+ 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a,
+ 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69,
0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46,
- 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18,
- 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69,
- 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69,
- 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a,
- 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18,
- 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69,
- 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69,
- 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a,
- 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43,
- 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c,
- 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43,
- 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c,
- 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f,
- 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e,
- 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
- 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
- 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69,
- 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69,
- 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69,
- 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53,
- 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69,
- 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d,
- 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69,
- 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20,
- 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a,
- 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f,
- 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63,
- 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a,
- 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f,
- 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65,
- 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55,
- 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74,
+ 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44,
+ 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69,
+ 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64,
+ 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
+ 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68,
+ 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72,
+ 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
+ 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66,
+ 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
+ 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f,
+ 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f,
+ 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
+ 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c,
+ 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f,
+ 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15,
+ 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f,
+ 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65,
+ 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e,
+ 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69,
+ 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69,
+ 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72,
+ 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08,
+ 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08,
+ 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65,
+ 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c,
+ 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06,
+ 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f,
+ 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65,
+ 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65,
+ 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e,
+ 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65,
+ 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
+ 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e,
+ 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e,
+ 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18,
+ 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a,
+ 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08,
+ 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15,
+ 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64,
+ 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74,
0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69,
- 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72,
- 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12,
- 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18,
- 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55,
- 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72,
- 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52,
- 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f,
- 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67,
- 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61,
- 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49,
- 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e,
- 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65,
- 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b,
- 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77,
- 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18,
- 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d,
- 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74,
- 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64,
- 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72,
- 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d,
- 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61,
- 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65,
- 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74,
- 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70,
- 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75,
- 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43,
- 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
- 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65,
- 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e,
- 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18,
- 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f,
- 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72,
- 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f,
- 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
- 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e,
- 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x24,
- 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18,
- 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72,
- 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f,
- 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65,
- 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61,
- 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52,
- 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74,
- 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12,
- 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61,
- 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03,
- 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18,
- 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03,
- 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14,
- 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52,
- 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72,
- 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65,
- 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e,
- 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53,
- 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65,
- 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20,
- 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07,
- 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44,
- 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68,
- 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04,
+ 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52,
+ 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65,
+ 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c,
+ 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01,
+ 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70,
+ 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46,
+ 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
+ 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e,
+ 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18,
+ 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
+ 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77,
+ 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e,
+ 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65,
+ 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16,
+ 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06,
+ 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65,
+ 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71,
+ 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18,
+ 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07,
+ 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44,
+ 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f,
+ 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52,
+ 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f,
+ 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69,
+ 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44,
+ 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76,
+ 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52,
+ 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47,
+ 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75,
+ 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
+ 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
+ 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65,
+ 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f,
+ 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d,
+ 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d,
+ 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65,
+ 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f,
+ 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f,
+ 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb4, 0x01, 0x0a, 0x0a,
+ 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f,
+ 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61,
+ 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20,
+ 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
+ 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52,
+ 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68,
+ 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03,
0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61,
- 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61,
- 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79,
- 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65,
- 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04,
- 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c,
- 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18,
- 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a,
- 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e,
- 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75,
- 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72,
- 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e,
- 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
- 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06,
- 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63,
- 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
- 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63,
- 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04,
- 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74,
- 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01,
- 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
- 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e,
- 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07,
- 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38,
- 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73,
- 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
- 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20,
- 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63,
- 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28,
- 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72,
- 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20,
- 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72,
- 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e,
- 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f,
- 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a,
- 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72,
- 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10,
- 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64,
- 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f,
- 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77,
- 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63,
- 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73,
- 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61,
- 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61,
- 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74,
- 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64,
- 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
- 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a,
- 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32,
- 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c,
- 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f,
- 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18,
- 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
- 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72,
- 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d,
- 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61,
- 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07,
- 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a,
- 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18,
- 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f,
- 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49,
- 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49,
- 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e,
- 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34,
- 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e,
- 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75,
- 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74,
- 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74,
- 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e,
- 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49,
- 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e,
- 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74,
- 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52,
- 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65,
- 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64,
- 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e,
- 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f,
- 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74,
- 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c,
- 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a,
- 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12,
- 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50,
- 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20,
- 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12,
- 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01,
- 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a,
- 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52,
- 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
- 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f,
- 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
- 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
- 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45,
- 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22,
- 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
- 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
- 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
+ 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x53, 0x6b,
+ 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01,
+ 0x28, 0x08, 0x52, 0x0e, 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65,
+ 0x73, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f,
+ 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02,
+ 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c,
+ 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73,
+ 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54,
+ 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28,
+ 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d,
+ 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b,
+ 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28,
+ 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e,
+ 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53,
+ 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72,
+ 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79,
+ 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28,
+ 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65,
+ 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c,
+ 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68,
+ 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48,
+ 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02,
+ 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06,
+ 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53,
+ 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01,
+ 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72,
+ 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65,
+ 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49,
+ 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02,
+ 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
+ 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52,
+ 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63,
+ 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e,
+ 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69,
+ 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72,
+ 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d,
+ 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72,
+ 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c,
+ 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04,
+ 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f,
+ 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
+ 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f,
+ 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79,
+ 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79,
+ 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64,
+ 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20,
+ 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61,
+ 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06,
+ 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18,
+ 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a,
+ 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72,
+ 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12,
+ 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a,
+ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74,
+ 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61,
+ 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05,
+ 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61,
+ 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52,
+ 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65,
+ 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46,
+ 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73,
+ 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28,
+ 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12,
+ 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32,
+ 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c,
+ 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12,
+ 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03,
+ 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f,
+ 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20,
+ 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
+ 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70,
+ 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49,
+ 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
+ 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52,
+ 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44,
+ 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73,
+ 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69,
+ 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e,
+ 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f,
+ 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f,
+ 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c,
+ 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c,
+ 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44,
+ 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22,
+ 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75,
+ 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01,
+ 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
+ 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08,
+ 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74,
+ 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28,
+ 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50,
+ 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61,
+ 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e,
+ 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20,
+ 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41,
+ 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c,
+ 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14,
+ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74,
+ 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64,
+ 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74,
+ 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
+ 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43,
+ 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04,
+ 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d,
+ 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74,
+ 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f,
+ 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69,
+ 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08,
+ 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e,
+ 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45,
+ 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65,
- 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74,
- 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
- 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d,
- 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
- 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a,
- 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e,
- 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e,
+ 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
+ 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
+ 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e,
+ 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79,
+ 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61,
+ 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
+ 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a,
+ 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e,
0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79,
- 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41,
- 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77,
- 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e,
- 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c,
- 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72,
- 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58,
- 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69,
- 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e,
- 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
- 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
- 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
- 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63,
- 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
+ 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65,
+ 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
+ 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11,
+ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74,
+ 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45,
+ 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76,
+ 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e,
+ 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61,
- 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
- 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75,
- 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45,
- 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a,
- 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70,
- 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06,
- 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+ 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
+ 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
+ 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74,
+ 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c,
+ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72,
+ 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d,
+ 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
+ 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08,
+ 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
+ 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
+ 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
+ 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c,
+ 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
+ 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
+ 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
+ 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f,
+ 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -3910,7 +4300,7 @@ func file_management_proto_rawDescGZIP() []byte {
}
var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5)
-var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 40)
+var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 45)
var file_management_proto_goTypes = []interface{}{
(RuleProtocol)(0), // 0: management.RuleProtocol
(RuleDirection)(0), // 1: management.RuleDirection
@@ -3934,107 +4324,117 @@ var file_management_proto_goTypes = []interface{}{
(*HostConfig)(nil), // 19: management.HostConfig
(*RelayConfig)(nil), // 20: management.RelayConfig
(*FlowConfig)(nil), // 21: management.FlowConfig
- (*ProtectedHostConfig)(nil), // 22: management.ProtectedHostConfig
- (*PeerConfig)(nil), // 23: management.PeerConfig
- (*NetworkMap)(nil), // 24: management.NetworkMap
- (*RemotePeerConfig)(nil), // 25: management.RemotePeerConfig
- (*SSHConfig)(nil), // 26: management.SSHConfig
- (*DeviceAuthorizationFlowRequest)(nil), // 27: management.DeviceAuthorizationFlowRequest
- (*DeviceAuthorizationFlow)(nil), // 28: management.DeviceAuthorizationFlow
- (*PKCEAuthorizationFlowRequest)(nil), // 29: management.PKCEAuthorizationFlowRequest
- (*PKCEAuthorizationFlow)(nil), // 30: management.PKCEAuthorizationFlow
- (*ProviderConfig)(nil), // 31: management.ProviderConfig
- (*Route)(nil), // 32: management.Route
- (*DNSConfig)(nil), // 33: management.DNSConfig
- (*CustomZone)(nil), // 34: management.CustomZone
- (*SimpleRecord)(nil), // 35: management.SimpleRecord
- (*NameServerGroup)(nil), // 36: management.NameServerGroup
- (*NameServer)(nil), // 37: management.NameServer
- (*FirewallRule)(nil), // 38: management.FirewallRule
- (*NetworkAddress)(nil), // 39: management.NetworkAddress
- (*Checks)(nil), // 40: management.Checks
- (*PortInfo)(nil), // 41: management.PortInfo
- (*RouteFirewallRule)(nil), // 42: management.RouteFirewallRule
- (*ForwardingRule)(nil), // 43: management.ForwardingRule
- (*PortInfo_Range)(nil), // 44: management.PortInfo.Range
- (*timestamppb.Timestamp)(nil), // 45: google.protobuf.Timestamp
- (*durationpb.Duration)(nil), // 46: google.protobuf.Duration
+ (*JWTConfig)(nil), // 22: management.JWTConfig
+ (*ProtectedHostConfig)(nil), // 23: management.ProtectedHostConfig
+ (*PeerConfig)(nil), // 24: management.PeerConfig
+ (*AutoUpdateSettings)(nil), // 25: management.AutoUpdateSettings
+ (*NetworkMap)(nil), // 26: management.NetworkMap
+ (*SSHAuth)(nil), // 27: management.SSHAuth
+ (*MachineUserIndexes)(nil), // 28: management.MachineUserIndexes
+ (*RemotePeerConfig)(nil), // 29: management.RemotePeerConfig
+ (*SSHConfig)(nil), // 30: management.SSHConfig
+ (*DeviceAuthorizationFlowRequest)(nil), // 31: management.DeviceAuthorizationFlowRequest
+ (*DeviceAuthorizationFlow)(nil), // 32: management.DeviceAuthorizationFlow
+ (*PKCEAuthorizationFlowRequest)(nil), // 33: management.PKCEAuthorizationFlowRequest
+ (*PKCEAuthorizationFlow)(nil), // 34: management.PKCEAuthorizationFlow
+ (*ProviderConfig)(nil), // 35: management.ProviderConfig
+ (*Route)(nil), // 36: management.Route
+ (*DNSConfig)(nil), // 37: management.DNSConfig
+ (*CustomZone)(nil), // 38: management.CustomZone
+ (*SimpleRecord)(nil), // 39: management.SimpleRecord
+ (*NameServerGroup)(nil), // 40: management.NameServerGroup
+ (*NameServer)(nil), // 41: management.NameServer
+ (*FirewallRule)(nil), // 42: management.FirewallRule
+ (*NetworkAddress)(nil), // 43: management.NetworkAddress
+ (*Checks)(nil), // 44: management.Checks
+ (*PortInfo)(nil), // 45: management.PortInfo
+ (*RouteFirewallRule)(nil), // 46: management.RouteFirewallRule
+ (*ForwardingRule)(nil), // 47: management.ForwardingRule
+ nil, // 48: management.SSHAuth.MachineUsersEntry
+ (*PortInfo_Range)(nil), // 49: management.PortInfo.Range
+ (*timestamppb.Timestamp)(nil), // 50: google.protobuf.Timestamp
+ (*durationpb.Duration)(nil), // 51: google.protobuf.Duration
}
var file_management_proto_depIdxs = []int32{
14, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta
18, // 1: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig
- 23, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig
- 25, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig
- 24, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap
- 40, // 5: management.SyncResponse.Checks:type_name -> management.Checks
+ 24, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig
+ 29, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig
+ 26, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap
+ 44, // 5: management.SyncResponse.Checks:type_name -> management.Checks
14, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta
14, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta
10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys
- 39, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress
+ 43, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress
11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment
12, // 11: management.PeerSystemMeta.files:type_name -> management.File
13, // 12: management.PeerSystemMeta.flags:type_name -> management.Flags
18, // 13: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig
- 23, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
- 40, // 15: management.LoginResponse.Checks:type_name -> management.Checks
- 45, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
+ 24, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
+ 44, // 15: management.LoginResponse.Checks:type_name -> management.Checks
+ 50, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
19, // 17: management.NetbirdConfig.stuns:type_name -> management.HostConfig
- 22, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig
+ 23, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig
19, // 19: management.NetbirdConfig.signal:type_name -> management.HostConfig
20, // 20: management.NetbirdConfig.relay:type_name -> management.RelayConfig
21, // 21: management.NetbirdConfig.flow:type_name -> management.FlowConfig
3, // 22: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol
- 46, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration
+ 51, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration
19, // 24: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig
- 26, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig
- 23, // 26: management.NetworkMap.peerConfig:type_name -> management.PeerConfig
- 25, // 27: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig
- 32, // 28: management.NetworkMap.Routes:type_name -> management.Route
- 33, // 29: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
- 25, // 30: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig
- 38, // 31: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule
- 42, // 32: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule
- 43, // 33: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule
- 26, // 34: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
- 4, // 35: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
- 31, // 36: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
- 31, // 37: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
- 36, // 38: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
- 34, // 39: management.DNSConfig.CustomZones:type_name -> management.CustomZone
- 35, // 40: management.CustomZone.Records:type_name -> management.SimpleRecord
- 37, // 41: management.NameServerGroup.NameServers:type_name -> management.NameServer
- 1, // 42: management.FirewallRule.Direction:type_name -> management.RuleDirection
- 2, // 43: management.FirewallRule.Action:type_name -> management.RuleAction
- 0, // 44: management.FirewallRule.Protocol:type_name -> management.RuleProtocol
- 41, // 45: management.FirewallRule.PortInfo:type_name -> management.PortInfo
- 44, // 46: management.PortInfo.range:type_name -> management.PortInfo.Range
- 2, // 47: management.RouteFirewallRule.action:type_name -> management.RuleAction
- 0, // 48: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol
- 41, // 49: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo
- 0, // 50: management.ForwardingRule.protocol:type_name -> management.RuleProtocol
- 41, // 51: management.ForwardingRule.destinationPort:type_name -> management.PortInfo
- 41, // 52: management.ForwardingRule.translatedPort:type_name -> management.PortInfo
- 5, // 53: management.ManagementService.Login:input_type -> management.EncryptedMessage
- 5, // 54: management.ManagementService.Sync:input_type -> management.EncryptedMessage
- 17, // 55: management.ManagementService.GetServerKey:input_type -> management.Empty
- 17, // 56: management.ManagementService.isHealthy:input_type -> management.Empty
- 5, // 57: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
- 5, // 58: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage
- 5, // 59: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage
- 5, // 60: management.ManagementService.Logout:input_type -> management.EncryptedMessage
- 5, // 61: management.ManagementService.Login:output_type -> management.EncryptedMessage
- 5, // 62: management.ManagementService.Sync:output_type -> management.EncryptedMessage
- 16, // 63: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
- 17, // 64: management.ManagementService.isHealthy:output_type -> management.Empty
- 5, // 65: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
- 5, // 66: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage
- 17, // 67: management.ManagementService.SyncMeta:output_type -> management.Empty
- 17, // 68: management.ManagementService.Logout:output_type -> management.Empty
- 61, // [61:69] is the sub-list for method output_type
- 53, // [53:61] is the sub-list for method input_type
- 53, // [53:53] is the sub-list for extension type_name
- 53, // [53:53] is the sub-list for extension extendee
- 0, // [0:53] is the sub-list for field type_name
+ 30, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig
+ 25, // 26: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings
+ 24, // 27: management.NetworkMap.peerConfig:type_name -> management.PeerConfig
+ 29, // 28: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig
+ 36, // 29: management.NetworkMap.Routes:type_name -> management.Route
+ 37, // 30: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
+ 29, // 31: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig
+ 42, // 32: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule
+ 46, // 33: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule
+ 47, // 34: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule
+ 27, // 35: management.NetworkMap.sshAuth:type_name -> management.SSHAuth
+ 48, // 36: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry
+ 30, // 37: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
+ 22, // 38: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig
+ 4, // 39: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
+ 35, // 40: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
+ 35, // 41: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
+ 40, // 42: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
+ 38, // 43: management.DNSConfig.CustomZones:type_name -> management.CustomZone
+ 39, // 44: management.CustomZone.Records:type_name -> management.SimpleRecord
+ 41, // 45: management.NameServerGroup.NameServers:type_name -> management.NameServer
+ 1, // 46: management.FirewallRule.Direction:type_name -> management.RuleDirection
+ 2, // 47: management.FirewallRule.Action:type_name -> management.RuleAction
+ 0, // 48: management.FirewallRule.Protocol:type_name -> management.RuleProtocol
+ 45, // 49: management.FirewallRule.PortInfo:type_name -> management.PortInfo
+ 49, // 50: management.PortInfo.range:type_name -> management.PortInfo.Range
+ 2, // 51: management.RouteFirewallRule.action:type_name -> management.RuleAction
+ 0, // 52: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol
+ 45, // 53: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo
+ 0, // 54: management.ForwardingRule.protocol:type_name -> management.RuleProtocol
+ 45, // 55: management.ForwardingRule.destinationPort:type_name -> management.PortInfo
+ 45, // 56: management.ForwardingRule.translatedPort:type_name -> management.PortInfo
+ 28, // 57: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes
+ 5, // 58: management.ManagementService.Login:input_type -> management.EncryptedMessage
+ 5, // 59: management.ManagementService.Sync:input_type -> management.EncryptedMessage
+ 17, // 60: management.ManagementService.GetServerKey:input_type -> management.Empty
+ 17, // 61: management.ManagementService.isHealthy:input_type -> management.Empty
+ 5, // 62: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
+ 5, // 63: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage
+ 5, // 64: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage
+ 5, // 65: management.ManagementService.Logout:input_type -> management.EncryptedMessage
+ 5, // 66: management.ManagementService.Login:output_type -> management.EncryptedMessage
+ 5, // 67: management.ManagementService.Sync:output_type -> management.EncryptedMessage
+ 16, // 68: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
+ 17, // 69: management.ManagementService.isHealthy:output_type -> management.Empty
+ 5, // 70: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
+ 5, // 71: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage
+ 17, // 72: management.ManagementService.SyncMeta:output_type -> management.Empty
+ 17, // 73: management.ManagementService.Logout:output_type -> management.Empty
+ 66, // [66:74] is the sub-list for method output_type
+ 58, // [58:66] is the sub-list for method input_type
+ 58, // [58:58] is the sub-list for extension type_name
+ 58, // [58:58] is the sub-list for extension extendee
+ 0, // [0:58] is the sub-list for field type_name
}
func init() { file_management_proto_init() }
@@ -4248,7 +4648,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*ProtectedHostConfig); i {
+ switch v := v.(*JWTConfig); i {
case 0:
return &v.state
case 1:
@@ -4260,7 +4660,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*PeerConfig); i {
+ switch v := v.(*ProtectedHostConfig); i {
case 0:
return &v.state
case 1:
@@ -4272,7 +4672,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NetworkMap); i {
+ switch v := v.(*PeerConfig); i {
case 0:
return &v.state
case 1:
@@ -4284,7 +4684,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*RemotePeerConfig); i {
+ switch v := v.(*AutoUpdateSettings); i {
case 0:
return &v.state
case 1:
@@ -4296,7 +4696,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*SSHConfig); i {
+ switch v := v.(*NetworkMap); i {
case 0:
return &v.state
case 1:
@@ -4308,7 +4708,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*DeviceAuthorizationFlowRequest); i {
+ switch v := v.(*SSHAuth); i {
case 0:
return &v.state
case 1:
@@ -4320,7 +4720,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*DeviceAuthorizationFlow); i {
+ switch v := v.(*MachineUserIndexes); i {
case 0:
return &v.state
case 1:
@@ -4332,7 +4732,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*PKCEAuthorizationFlowRequest); i {
+ switch v := v.(*RemotePeerConfig); i {
case 0:
return &v.state
case 1:
@@ -4344,7 +4744,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*PKCEAuthorizationFlow); i {
+ switch v := v.(*SSHConfig); i {
case 0:
return &v.state
case 1:
@@ -4356,7 +4756,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*ProviderConfig); i {
+ switch v := v.(*DeviceAuthorizationFlowRequest); i {
case 0:
return &v.state
case 1:
@@ -4368,7 +4768,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*Route); i {
+ switch v := v.(*DeviceAuthorizationFlow); i {
case 0:
return &v.state
case 1:
@@ -4380,7 +4780,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*DNSConfig); i {
+ switch v := v.(*PKCEAuthorizationFlowRequest); i {
case 0:
return &v.state
case 1:
@@ -4392,7 +4792,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*CustomZone); i {
+ switch v := v.(*PKCEAuthorizationFlow); i {
case 0:
return &v.state
case 1:
@@ -4404,7 +4804,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*SimpleRecord); i {
+ switch v := v.(*ProviderConfig); i {
case 0:
return &v.state
case 1:
@@ -4416,7 +4816,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NameServerGroup); i {
+ switch v := v.(*Route); i {
case 0:
return &v.state
case 1:
@@ -4428,7 +4828,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NameServer); i {
+ switch v := v.(*DNSConfig); i {
case 0:
return &v.state
case 1:
@@ -4440,7 +4840,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*FirewallRule); i {
+ switch v := v.(*CustomZone); i {
case 0:
return &v.state
case 1:
@@ -4452,7 +4852,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NetworkAddress); i {
+ switch v := v.(*SimpleRecord); i {
case 0:
return &v.state
case 1:
@@ -4464,7 +4864,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*Checks); i {
+ switch v := v.(*NameServerGroup); i {
case 0:
return &v.state
case 1:
@@ -4476,7 +4876,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*PortInfo); i {
+ switch v := v.(*NameServer); i {
case 0:
return &v.state
case 1:
@@ -4488,7 +4888,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*RouteFirewallRule); i {
+ switch v := v.(*FirewallRule); i {
case 0:
return &v.state
case 1:
@@ -4500,7 +4900,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*ForwardingRule); i {
+ switch v := v.(*NetworkAddress); i {
case 0:
return &v.state
case 1:
@@ -4512,6 +4912,54 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*Checks); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_management_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*PortInfo); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_management_proto_msgTypes[41].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*RouteFirewallRule); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_management_proto_msgTypes[42].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*ForwardingRule); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_management_proto_msgTypes[44].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PortInfo_Range); i {
case 0:
return &v.state
@@ -4524,7 +4972,7 @@ func file_management_proto_init() {
}
}
}
- file_management_proto_msgTypes[36].OneofWrappers = []interface{}{
+ file_management_proto_msgTypes[40].OneofWrappers = []interface{}{
(*PortInfo_Port)(nil),
(*PortInfo_Range_)(nil),
}
@@ -4534,7 +4982,7 @@ func file_management_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_management_proto_rawDesc,
NumEnums: 5,
- NumMessages: 40,
+ NumMessages: 45,
NumExtensions: 0,
NumServices: 1,
},
diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto
index ad82d37d9..f2e591e88 100644
--- a/shared/management/proto/management.proto
+++ b/shared/management/proto/management.proto
@@ -146,6 +146,12 @@ message Flags {
bool blockInbound = 9;
bool lazyConnectionEnabled = 10;
+
+ bool enableSSHRoot = 11;
+ bool enableSSHSFTP = 12;
+ bool enableSSHLocalPortForwarding = 13;
+ bool enableSSHRemotePortForwarding = 14;
+ bool disableSSHAuth = 15;
}
// PeerSystemMeta is machine meta data like OS and version.
@@ -240,6 +246,14 @@ message FlowConfig {
bool dnsCollection = 8;
}
+// JWTConfig represents JWT authentication configuration
+message JWTConfig {
+ string issuer = 1;
+ string audience = 2;
+ string keysLocation = 3;
+ int64 maxTokenAge = 4;
+}
+
// ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers
message ProtectedHostConfig {
@@ -266,6 +280,18 @@ message PeerConfig {
bool LazyConnectionEnabled = 6;
int32 mtu = 7;
+
+ // Auto-update config
+ AutoUpdateSettings autoUpdate = 8;
+}
+
+message AutoUpdateSettings {
+ string version = 1;
+ /*
+ alwaysUpdate = true → Updates happen automatically in the background
+ alwaysUpdate = false → Updates only happen when triggered by a peer connection
+ */
+ bool alwaysUpdate = 2;
}
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
@@ -306,6 +332,24 @@ message NetworkMap {
bool routesFirewallRulesIsEmpty = 11;
repeated ForwardingRule forwardingRules = 12;
+
+ // SSHAuth represents SSH authorization configuration
+ SSHAuth sshAuth = 13;
+}
+
+message SSHAuth {
+ // UserIDClaim is the JWT claim to be used to get the users ID
+ string UserIDClaim = 1;
+
+ // AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH
+ repeated bytes AuthorizedUsers = 2;
+
+ // MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list
+ map machine_users = 3;
+}
+
+message MachineUserIndexes {
+ repeated uint32 indexes = 1;
}
// RemotePeerConfig represents a configuration of a remote peer.
@@ -335,6 +379,8 @@ message SSHConfig {
// sshPubKey is a SSH public key of a peer to be added to authorized_hosts.
// This property should be ignore if SSHConfig comes from PeerConfig.
bytes sshPubKey = 2;
+
+ JWTConfig jwtConfig = 3;
}
// DeviceAuthorizationFlowRequest empty struct for future expansion
@@ -410,13 +456,15 @@ message DNSConfig {
bool ServiceEnable = 1;
repeated NameServerGroup NameServerGroups = 2;
repeated CustomZone CustomZones = 3;
- int64 ForwarderPort = 4;
+ int64 ForwarderPort = 4 [deprecated = true];
}
// CustomZone represents a dns.CustomZone
message CustomZone {
string Domain = 1;
repeated SimpleRecord Records = 2;
+ bool SearchDomainDisabled = 3;
+ bool SkipPTRProcess = 4;
}
// SimpleRecord represents a dns.SimpleRecord
diff --git a/shared/management/status/error.go b/shared/management/status/error.go
index 1e914babb..09676847e 100644
--- a/shared/management/status/error.go
+++ b/shared/management/status/error.go
@@ -37,6 +37,9 @@ const (
// Unauthenticated indicates that user is not authenticated due to absence of valid credentials
Unauthenticated Type = 10
+
+ // TooManyRequests indicates that the user has sent too many requests in a given amount of time (rate limiting)
+ TooManyRequests Type = 11
)
// Type is a type of the Error
diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go
index 967e18d79..c057ef089 100644
--- a/shared/relay/client/dialer/quic/quic.go
+++ b/shared/relay/client/dialer/quic/quic.go
@@ -11,8 +11,8 @@ import (
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
- quictls "github.com/netbirdio/netbird/shared/relay/tls"
nbnet "github.com/netbirdio/netbird/client/net"
+ quictls "github.com/netbirdio/netbird/shared/relay/tls"
)
type Dialer struct {
diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go
index 66fff3447..37b189e05 100644
--- a/shared/relay/client/dialer/ws/ws.go
+++ b/shared/relay/client/dialer/ws/ws.go
@@ -14,9 +14,9 @@ import (
"github.com/coder/websocket"
log "github.com/sirupsen/logrus"
+ nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/shared/relay"
"github.com/netbirdio/netbird/util/embeddedroots"
- nbnet "github.com/netbirdio/netbird/client/net"
)
type Dialer struct {
diff --git a/shared/relay/constants.go b/shared/relay/constants.go
index 3c7c3cd29..0f2a27610 100644
--- a/shared/relay/constants.go
+++ b/shared/relay/constants.go
@@ -3,4 +3,4 @@ package relay
const (
// WebSocketURLPath is the path for the websocket relay connection
WebSocketURLPath = "/relay"
-)
\ No newline at end of file
+)
diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go
index 31f3372c0..5368b57a2 100644
--- a/shared/signal/client/grpc.go
+++ b/shared/signal/client/grpc.go
@@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
if err != nil {
- log.Printf("createConnection error: %v", err)
- return err
+ return fmt.Errorf("create connection: %w", err)
}
return nil
}
diff --git a/shared/sshauth/userhash.go b/shared/sshauth/userhash.go
new file mode 100644
index 000000000..276fc9ba2
--- /dev/null
+++ b/shared/sshauth/userhash.go
@@ -0,0 +1,28 @@
+package sshauth
+
+import (
+ "encoding/hex"
+
+ "golang.org/x/crypto/blake2b"
+)
+
+// UserIDHash represents a hashed user ID (BLAKE2b-128)
+type UserIDHash [16]byte
+
+// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value
+// This function must produce the same hash on both client and management server
+func HashUserID(userID string) (UserIDHash, error) {
+ hash, err := blake2b.New(16, nil)
+ if err != nil {
+ return UserIDHash{}, err
+ }
+ hash.Write([]byte(userID))
+ var result UserIDHash
+ copy(result[:], hash.Sum(nil))
+ return result, nil
+}
+
+// String returns the hexadecimal string representation of the hash
+func (h UserIDHash) String() string {
+ return hex.EncodeToString(h[:])
+}
diff --git a/shared/sshauth/userhash_test.go b/shared/sshauth/userhash_test.go
new file mode 100644
index 000000000..5a3cb6986
--- /dev/null
+++ b/shared/sshauth/userhash_test.go
@@ -0,0 +1,210 @@
+package sshauth
+
+import (
+ "testing"
+)
+
+func TestHashUserID(t *testing.T) {
+ tests := []struct {
+ name string
+ userID string
+ }{
+ {
+ name: "simple user ID",
+ userID: "user@example.com",
+ },
+ {
+ name: "UUID format",
+ userID: "550e8400-e29b-41d4-a716-446655440000",
+ },
+ {
+ name: "numeric ID",
+ userID: "12345",
+ },
+ {
+ name: "empty string",
+ userID: "",
+ },
+ {
+ name: "special characters",
+ userID: "user+test@domain.com",
+ },
+ {
+ name: "unicode characters",
+ userID: "用户@example.com",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ hash, err := HashUserID(tt.userID)
+ if err != nil {
+ t.Errorf("HashUserID() error = %v, want nil", err)
+ return
+ }
+
+ // Verify hash is non-zero for non-empty inputs
+ if tt.userID != "" && hash == [16]byte{} {
+ t.Errorf("HashUserID() returned zero hash for non-empty input")
+ }
+ })
+ }
+}
+
+func TestHashUserID_Consistency(t *testing.T) {
+ userID := "test@example.com"
+
+ hash1, err1 := HashUserID(userID)
+ if err1 != nil {
+ t.Fatalf("First HashUserID() error = %v", err1)
+ }
+
+ hash2, err2 := HashUserID(userID)
+ if err2 != nil {
+ t.Fatalf("Second HashUserID() error = %v", err2)
+ }
+
+ if hash1 != hash2 {
+ t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2)
+ }
+}
+
+func TestHashUserID_Uniqueness(t *testing.T) {
+ tests := []struct {
+ userID1 string
+ userID2 string
+ }{
+ {"user1@example.com", "user2@example.com"},
+ {"alice@domain.com", "bob@domain.com"},
+ {"test", "test1"},
+ {"", "a"},
+ }
+
+ for _, tt := range tests {
+ hash1, err1 := HashUserID(tt.userID1)
+ if err1 != nil {
+ t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1)
+ }
+
+ hash2, err2 := HashUserID(tt.userID2)
+ if err2 != nil {
+ t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2)
+ }
+
+ if hash1 == hash2 {
+ t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1)
+ }
+ }
+}
+
+func TestUserIDHash_String(t *testing.T) {
+ tests := []struct {
+ name string
+ hash UserIDHash
+ expected string
+ }{
+ {
+ name: "zero hash",
+ hash: [16]byte{},
+ expected: "00000000000000000000000000000000",
+ },
+ {
+ name: "small value",
+ hash: [16]byte{15: 0xff},
+ expected: "000000000000000000000000000000ff",
+ },
+ {
+ name: "large value",
+ hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe},
+ expected: "0000000000000000deadbeefcafebabe",
+ },
+ {
+ name: "max value",
+ hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ expected: "ffffffffffffffffffffffffffffffff",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.hash.String()
+ if result != tt.expected {
+ t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected)
+ }
+ })
+ }
+}
+
+func TestUserIDHash_String_Length(t *testing.T) {
+ // Test that String() always returns 32 hex characters (16 bytes * 2)
+ userID := "test@example.com"
+ hash, err := HashUserID(userID)
+ if err != nil {
+ t.Fatalf("HashUserID() error = %v", err)
+ }
+
+ result := hash.String()
+ if len(result) != 32 {
+ t.Errorf("UserIDHash.String() length = %d, want 32", len(result))
+ }
+
+ // Verify it's valid hex
+ for i, c := range result {
+ if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
+ t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c)
+ }
+ }
+}
+
+func TestHashUserID_KnownValues(t *testing.T) {
+ // Test with known BLAKE2b-128 values to ensure correct implementation
+ tests := []struct {
+ name string
+ userID string
+ expected UserIDHash
+ }{
+ {
+ name: "empty string",
+ userID: "",
+ // BLAKE2b-128 of empty string
+ expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70},
+ },
+ {
+ name: "single character 'a'",
+ userID: "a",
+ // BLAKE2b-128 of "a"
+ expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ hash, err := HashUserID(tt.userID)
+ if err != nil {
+ t.Errorf("HashUserID() error = %v", err)
+ return
+ }
+
+ if hash != tt.expected {
+ t.Errorf("HashUserID(%q) = %x, want %x",
+ tt.userID, hash, tt.expected)
+ }
+ })
+ }
+}
+
+func BenchmarkHashUserID(b *testing.B) {
+ userID := "user@example.com"
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, _ = HashUserID(userID)
+ }
+}
+
+func BenchmarkUserIDHash_String(b *testing.B) {
+ hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe})
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _ = hash.String()
+ }
+}
diff --git a/signal/cmd/run.go b/signal/cmd/run.go
index 96873dee7..bf8f8e327 100644
--- a/signal/cmd/run.go
+++ b/signal/cmd/run.go
@@ -94,7 +94,7 @@ var (
startPprof()
- opts, certManager, err := getTLSConfigurations()
+ opts, certManager, tlsConfig, err := getTLSConfigurations()
if err != nil {
return err
}
@@ -132,7 +132,7 @@ var (
// Start the main server - always serve HTTP with WebSocket proxy support
// If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager
- if certManager == nil {
+ if tlsConfig == nil {
// Without TLS, serve plain HTTP
httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort))
if err != nil {
@@ -140,9 +140,10 @@ var (
}
log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String())
serveHTTP(httpListener, grpcRootHandler)
- } else if signalPort != 443 {
- // With TLS but not on port 443, serve HTTPS
- httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig())
+ } else if certManager == nil || signalPort != 443 {
+ // Serve HTTPS if not already handled by startServerWithCertManager
+ // (custom certificates or Let's Encrypt with custom port)
+ httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig)
if err != nil {
return err
}
@@ -202,7 +203,7 @@ func startPprof() {
}()
}
-func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
+func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) {
var (
err error
certManager *autocert.Manager
@@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" {
log.Infof("running without TLS")
- return nil, nil, nil
+ return nil, nil, nil, nil
}
if signalLetsencryptDomain != "" {
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
if err != nil {
- return nil, certManager, err
+ return nil, certManager, nil, err
}
tlsConfig = certManager.TLSConfig()
log.Infof("setting up TLS with LetsEncrypt.")
} else {
if signalCertFile == "" || signalCertKey == "" {
log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt")
- return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
+ return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
}
tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
- return nil, certManager, err
+ return nil, certManager, nil, err
}
log.Infof("setting up TLS with custom certificates.")
}
transportCredentials := credentials.NewTLS(tlsConfig)
- return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err
+ return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err
}
func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {
diff --git a/util/common.go b/util/common.go
index 27adb9d13..89903b609 100644
--- a/util/common.go
+++ b/util/common.go
@@ -1,6 +1,19 @@
package util
-import "os"
+import (
+ "os"
+ "os/exec"
+
+ "github.com/skratchdot/open-golang/open"
+)
+
+// OpenBrowser opens the URL in a browser, respecting the BROWSER environment variable.
+func OpenBrowser(url string) error {
+ if browser := os.Getenv("BROWSER"); browser != "" {
+ return exec.Command(browser, url).Start()
+ }
+ return open.Run(url)
+}
// SliceDiff returns the elements in slice `x` that are not in slice `y`
func SliceDiff(x, y []string) []string {
diff --git a/version/update.go b/version/update.go
index 272eef4c6..a324d97fe 100644
--- a/version/update.go
+++ b/version/update.go
@@ -41,21 +41,28 @@ func NewUpdate(httpAgent string) *Update {
currentVersion, _ = goversion.NewVersion("0.0.0")
}
- latestAvailable, _ := goversion.NewVersion("0.0.0")
-
u := &Update{
- httpAgent: httpAgent,
- latestAvailable: latestAvailable,
- uiVersion: currentVersion,
- fetchTicker: time.NewTicker(fetchPeriod),
- fetchDone: make(chan struct{}),
+ httpAgent: httpAgent,
+ uiVersion: currentVersion,
+ fetchDone: make(chan struct{}),
}
- go u.startFetcher()
+
+ return u
+}
+
+func NewUpdateAndStart(httpAgent string) *Update {
+ u := NewUpdate(httpAgent)
+ go u.StartFetcher()
+
return u
}
// StopWatch stop the version info fetch loop
func (u *Update) StopWatch() {
+ if u.fetchTicker == nil {
+ return
+ }
+
u.fetchTicker.Stop()
select {
@@ -94,7 +101,18 @@ func (u *Update) SetOnUpdateListener(updateFn func()) {
}
}
-func (u *Update) startFetcher() {
+func (u *Update) LatestVersion() *goversion.Version {
+ u.versionsLock.Lock()
+ defer u.versionsLock.Unlock()
+ return u.latestAvailable
+}
+
+func (u *Update) StartFetcher() {
+ if u.fetchTicker != nil {
+ return
+ }
+ u.fetchTicker = time.NewTicker(fetchPeriod)
+
if changed := u.fetchVersion(); changed {
u.checkUpdate()
}
@@ -181,6 +199,10 @@ func (u *Update) isUpdateAvailable() bool {
u.versionsLock.Lock()
defer u.versionsLock.Unlock()
+ if u.latestAvailable == nil {
+ return false
+ }
+
if u.latestAvailable.GreaterThan(u.uiVersion) {
return true
}
diff --git a/version/update_test.go b/version/update_test.go
index a733714cf..d5d60800e 100644
--- a/version/update_test.go
+++ b/version/update_test.go
@@ -23,7 +23,7 @@ func TestNewUpdate(t *testing.T) {
wg.Add(1)
onUpdate := false
- u := NewUpdate(httpAgent)
+ u := NewUpdateAndStart(httpAgent)
defer u.StopWatch()
u.SetOnUpdateListener(func() {
onUpdate = true
@@ -48,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) {
wg.Add(1)
onUpdate := false
- u := NewUpdate(httpAgent)
+ u := NewUpdateAndStart(httpAgent)
defer u.StopWatch()
u.SetOnUpdateListener(func() {
onUpdate = true
@@ -73,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) {
wg.Add(1)
onUpdate := false
- u := NewUpdate(httpAgent)
+ u := NewUpdateAndStart(httpAgent)
defer u.StopWatch()
u.SetOnUpdateListener(func() {
onUpdate = true
diff --git a/version/url_windows.go b/version/url_windows.go
index 14fdb7ae6..a0fb6e5dd 100644
--- a/version/url_windows.go
+++ b/version/url_windows.go
@@ -6,7 +6,7 @@ import (
)
const (
- urlWinExe = "https://pkgs.netbird.io/windows/x64"
+ urlWinExe = "https://pkgs.netbird.io/windows/x64"
urlWinExeArm = "https://pkgs.netbird.io/windows/arm64"
)
@@ -18,11 +18,11 @@ func DownloadUrl() string {
if err != nil {
return downloadURL
}
-
+
url := urlWinExe
if runtime.GOARCH == "arm64" {
url = urlWinExeArm
}
-
+
return url
}