Compare commits

..

91 Commits

Author SHA1 Message Date
Viktor Liu
24b66fb406 Translate usernames to UPN format for domain login 2025-11-05 22:27:08 +01:00
Viktor Liu
9378b6b0a3 Merge branch 'ssh-rewrite' into move-licensed-code 2025-11-05 16:09:03 +01:00
Viktor Liu
3779a3385f Fix tests 2025-11-05 13:06:54 +01:00
Viktor Liu
b5d75ad9c4 Go fmt everything 2025-11-05 12:59:36 +01:00
Viktor Liu
8db91abfdf Merge branch 'main' into ssh-rewrite 2025-11-05 12:44:17 +01:00
Viktor Liu
6f817cad6d Remove duplicate code 2025-11-03 13:47:33 +01:00
Viktor Liu
e3bb8c1b7b Merge branch 'main' into ssh-rewrite 2025-11-03 13:43:29 +01:00
Viktor Liu
107066fa3d Merge branch 'main' into ssh-rewrite 2025-10-28 22:08:46 +01:00
Viktor Liu
a7a85d4dc8 Fix tests 2025-10-28 21:11:45 +01:00
Viktor Liu
576b4a779c Log shell 2025-10-28 18:15:53 +01:00
Viktor Liu
e6854dfd99 Improve session logging 2025-10-28 17:57:59 +01:00
Viktor Liu
6f14134988 Merge branch 'main' into ssh-rewrite 2025-10-28 16:50:23 +01:00
Viktor Liu
4fd64379da Move client-imported GPL code to separate package 2025-10-23 23:52:44 +02:00
Viktor Liu
c20202a6c3 Add new flags to test 2025-10-17 16:15:05 +02:00
Viktor Liu
4386a21956 Merge branch 'main' into ssh-rewrite 2025-10-17 15:34:36 +02:00
Zoltan Papp
5882daf5d9 Force relay connection, do not waste signaling resources on ICE connection (#4628) 2025-10-13 11:02:21 +02:00
Viktor Liu
11d71e6e22 Ignore default log file 2025-10-10 16:21:39 +02:00
Viktor Liu
4dadcfd9bd Remove client.log check 2025-10-10 16:17:46 +02:00
Viktor Liu
34b55c600e Log errors on debug 2025-10-10 16:11:13 +02:00
Viktor Liu
316c0afa9a Remove unused arg 2025-10-10 11:08:34 +02:00
Viktor Liu
cf97799db8 Fix test 2025-10-10 10:23:45 +02:00
Viktor Liu
4d297205c3 Fix test build 2025-10-09 17:26:25 +02:00
Viktor Liu
559f6aeeaf Improve logging 2025-10-08 18:54:56 +02:00
Viktor Liu
7216c201da Log priv check errors 2025-10-08 18:46:02 +02:00
Viktor Liu
4d89d0f115 Remove unused code 2025-10-08 18:39:41 +02:00
Viktor Liu
610c880ec9 Fix missing jwt config passed to peers 2025-10-08 16:47:11 +02:00
Viktor Liu
19adcb5f63 Merge branch 'main' into ssh-rewrite 2025-10-08 12:40:07 +02:00
Viktor Liu
f3d31698da Skip some auth tests on windows that are already covered 2025-10-07 23:39:01 +02:00
Viktor Liu
d9efe4e944 Add ssh authenatication with jwt (#4550) 2025-10-07 23:38:27 +02:00
Viktor Liu
7e0bbaaa3c Merge branch 'main' into ssh-rewrite 2025-10-07 09:41:07 +02:00
Viktor Liu
b3c7b3c7b2 Fix js build 2025-10-02 15:59:17 +02:00
Viktor Liu
66483ab48d Merge branch 'main' into ssh-rewrite 2025-10-02 15:53:12 +02:00
Viktor Liu
5272fc2b18 Merge branch 'main' into ssh-rewrite 2025-09-25 11:12:47 +02:00
Viktor Liu
4c53372815 Add missing flags 2025-08-27 09:59:12 +02:00
Viktor Liu
79d28b71ee Improve forwarding cancellation 2025-08-26 22:22:15 +02:00
Viktor Liu
77a352763d Fix button style 2025-08-26 21:19:04 +02:00
Viktor Liu
cdd5c6c005 Address review 2025-08-26 21:01:55 +02:00
Viktor Liu
b1a9242c98 Fix merge commit changes 2025-08-26 20:43:29 +02:00
Viktor Liu
b43ef4f17b Merge branch 'main' into ssh-rewrite 2025-08-26 20:09:47 +02:00
Viktor Liu
758a97c352 Generate ssh_config independently of ssh server 2025-07-14 22:02:41 +02:00
Viktor Liu
d93b7c2f38 Fix known hosts entries 2025-07-14 21:41:59 +02:00
Viktor Liu
fa893aa0a4 Fix build 2025-07-12 00:49:08 +02:00
Viktor Liu
ac7120871b Fix proto 2025-07-12 00:11:31 +02:00
Viktor Liu
9a7daa132e Fix client ssh file 2025-07-11 22:08:28 +02:00
Viktor Liu
cdded8c22e Merge branch 'main' into ssh-rewrite 2025-07-11 22:05:12 +02:00
Viktor Liu
e4e0b8fff9 Remove empty file 2025-07-04 17:09:54 +02:00
Viktor Liu
a4b067553d Merge branch 'main' into ssh-rewrite 2025-07-04 16:53:54 +02:00
Viktor Liu
088956645f Fix username validation and skip ci tests properly 2025-07-03 15:36:42 +02:00
Viktor Liu
aa30b7afe8 More windows tests 2025-07-03 14:11:20 +02:00
Viktor Liu
f1bb4d2ac3 Fix more Windows tests 2025-07-03 13:35:53 +02:00
Viktor Liu
982841e25b Test up tests users if none are available on CI 2025-07-03 12:33:31 +02:00
Viktor Liu
a476b8d12f Fix more windows tests 2025-07-03 11:26:04 +02:00
Viktor Liu
a21f924b26 Fix some windows tests 2025-07-03 10:20:16 +02:00
Viktor Liu
9e51d2e8fb Fix lint and sonar 2025-07-03 09:58:25 +02:00
Viktor Liu
3e490d974c Remove duplicated code 2025-07-03 03:40:27 +02:00
Viktor Liu
04bb314426 Allow sftp same user switching on windows 2025-07-03 02:19:12 +02:00
Viktor Liu
6e15882c11 Fix tests and windows username validation 2025-07-03 01:58:15 +02:00
Viktor Liu
76f9e11b29 Fix tests 2025-07-03 01:07:58 +02:00
Viktor Liu
612de2c784 Remove socketfilter temporarily 2025-07-02 22:00:10 +02:00
Viktor Liu
1fdde66c31 More lint 2025-07-02 21:55:25 +02:00
Viktor Liu
5970591d24 Fix lint 2025-07-02 21:32:39 +02:00
Viktor Liu
0d5408baec Fix lint 2025-07-02 21:04:58 +02:00
Viktor Liu
96084e3a02 Reduce complexity 2025-07-02 20:43:17 +02:00
Viktor Liu
4bbca28eb6 Fix lint 2025-07-02 20:23:23 +02:00
Viktor Liu
279b77dee0 Bump sftp 2025-07-02 19:42:57 +02:00
Viktor Liu
9d1554f9f7 Complete overhaul 2025-07-02 19:35:19 +02:00
Viktor Liu
f56075ca15 Tidy mod 2025-07-02 19:34:36 +02:00
Viktor Liu
6ed846ae29 Refactor ssh server and client 2025-07-02 19:34:36 +02:00
Viktor Liu
520f2cfdb4 Remove implicit inbound ssh firewall rules and change default port 2025-07-02 19:34:32 +02:00
Viktor Liu
0f79a8942d Fix route notificaiton 2025-07-02 17:24:14 +02:00
Viktor Liu
5299e9fda3 Merge branch 'main' into android-dns-routes 2025-07-02 15:23:14 +02:00
Viktor Liu
11bdf5b3a5 Use r 2025-06-26 15:41:56 +02:00
Viktor Liu
5fc95d4a0c Display domains properly 2025-06-26 15:36:14 +02:00
Viktor Liu
c7884039b8 Revert "Fix errorf"
This reverts commit 26fc32f1be.
2025-06-25 15:17:31 +02:00
Viktor Liu
26fc32f1be Fix errorf 2025-06-25 15:03:55 +02:00
Viktor Liu
a79cb1c11b Merge branch 'main' into android-dns-routes 2025-06-18 17:27:13 +02:00
Viktor Liu
306d75fe1a Set up fake ip route only if the dns feature flag is enabled 2025-06-17 22:29:13 +02:00
Viktor Liu
9468e69c8c Extract static error 2025-06-17 21:47:05 +02:00
Viktor Liu
f51ce7cee5 Remove nil checks 2025-06-17 21:41:58 +02:00
Viktor Liu
d47c6b624e Fix spelling 2025-06-17 20:02:52 +02:00
Viktor Liu
471f90e8db Rename methods 2025-06-17 15:52:34 +02:00
Viktor Liu
1a3b04d2fe Swap tracking and nat order 2025-06-17 15:45:22 +02:00
Viktor Liu
51b9e93eb9 Merge branch 'main' into android-dns-routes 2025-06-17 15:12:05 +02:00
Viktor Liu
2952669e97 Fix lint 2025-06-17 14:16:59 +02:00
Viktor Liu
7cd44a9a3c Improve nat perf 2025-06-17 13:55:57 +02:00
Viktor Liu
8684981b57 Add tests 2025-06-17 13:41:06 +02:00
Viktor Liu
8e94d85d14 Rename test files 2025-06-17 12:46:17 +02:00
Viktor Liu
631b77dc3c Remove some allocations 2025-06-17 12:44:52 +02:00
Viktor Liu
50ac3d437e Fix lint issues 2025-06-17 03:07:28 +02:00
Viktor Liu
49bbd90557 Fix test 2025-06-17 02:57:15 +02:00
Viktor Liu
bb74e903cd Implement dns routes for Android 2025-06-17 02:48:13 +02:00
179 changed files with 4638 additions and 16830 deletions

View File

@@ -3,108 +3,40 @@ name: Check License Dependencies
on: on:
push: push:
branches: [ main ] branches: [ main ]
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
pull_request: pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
jobs: jobs:
check-internal-dependencies: check-dependencies:
name: Check Internal AGPL Dependencies
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo ""
echo "✅ All internal license dependencies are clean"
fi
check-external-licenses:
name: Check External GPL/AGPL Licenses
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Go - name: Check for problematic license dependencies
uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache: true
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Check for GPL/AGPL licensed dependencies
run: | run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..." echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo "" echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages # Find all directories except the problematic ones and system dirs
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true) FOUND_ISSUES=0
while IFS= read -r dir; do
if [ -n "$COPYLEFT_DEPS" ]; then echo "=== Checking $dir ==="
echo "Found copyleft licensed dependencies:" # Search for problematic imports, excluding test files
echo "$COPYLEFT_DEPS" RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
echo "" if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
# Filter out dependencies that are only pulled in by internal AGPL packages echo "$RESULTS"
INCOMPATIBLE="" FOUND_ISSUES=1
while IFS=',' read -r package url license; do else
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then echo "✓ No problematic dependencies found"
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi fi
fi done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo "✅ All external license dependencies are compatible with BSD-3-Clause" echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo "✅ All license dependencies are clean"
fi

View File

@@ -47,7 +47,7 @@ jobs:
with: with:
go-version: "1.23.x" go-version: "1.23.x"
- name: Build Wasm client - name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm -ldflags="-s -w" ./client/wasm/cmd run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env: env:
CGO_ENABLED: 0 CGO_ENABLED: 0
- name: Check Wasm build size - name: Check Wasm build size

View File

@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
} }
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "") oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types" "github.com/netbirdio/netbird/upload-server/types"
) )
@@ -97,6 +98,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag, SystemInfo: systemInfoFlag,
LogFileCount: logFileCount, LogFileCount: logFileCount,
} }
@@ -218,6 +220,9 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr return waitErr
} }
@@ -225,8 +230,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Creating debug bundle...") cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag, SystemInfo: systemInfoFlag,
LogFileCount: logFileCount, LogFileCount: logFileCount,
} }
@@ -293,6 +301,25 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context(), true)
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
)
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error { func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second) ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop() defer ticker.Stop()
@@ -351,7 +378,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
InternalConfig: config, InternalConfig: config,
StatusRecorder: recorder, StatusRecorder: recorder,
SyncResponse: syncResponse, SyncResponse: syncResponse,
LogPath: logFilePath, LogFile: logFilePath,
}, },
debug.BundleConfig{ debug.BundleConfig{
IncludeSystemInfo: true, IncludeSystemInfo: true,

View File

@@ -106,13 +106,6 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
Username: &username, Username: &username,
} }
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey loginRequest.OptionalPreSharedKey = &preSharedKey
} }
@@ -248,7 +241,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err) return fmt.Errorf("read config file %s: %v", configFilePath, err)
} }
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name) err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -276,7 +269,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil return nil
} }
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error { func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
needsLogin := false needsLogin := false
err := WithBackOff(func() error { err := WithBackOff(func() error {
@@ -293,7 +286,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := "" jwtToken := ""
if setupKey == "" && needsLogin { if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName) tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
if err != nil { if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err) return fmt.Errorf("interactive sso login failed: %v", err)
} }
@@ -322,17 +315,8 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil return nil
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
hint := "" oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileName)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
hint = profileState.Email
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -259,7 +259,6 @@ func isServiceRunning() (bool, error) {
} }
const ( const (
networkdConf = "/etc/systemd/networkd.conf"
networkdConfDir = "/etc/systemd/networkd.conf.d" networkdConfDir = "/etc/systemd/networkd.conf.d"
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf" networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
@@ -274,16 +273,12 @@ ManageForeignRoutingPolicyRules=no
// configureSystemdNetworkd creates a drop-in configuration file to prevent // configureSystemdNetworkd creates a drop-in configuration file to prevent
// systemd-networkd from removing NetBird's routes and policy rules. // systemd-networkd from removing NetBird's routes and policy rules.
func configureSystemdNetworkd() error { func configureSystemdNetworkd() error {
if _, err := os.Stat(networkdConf); os.IsNotExist(err) { parentDir := filepath.Dir(networkdConfDir)
log.Debug("systemd-networkd not in use, skipping configuration") if _, err := os.Stat(parentDir); os.IsNotExist(err) {
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
return nil return nil
} }
// nolint:gosec // standard networkd permissions
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
return fmt.Errorf("create networkd.conf.d directory: %w", err)
}
// nolint:gosec // standard networkd permissions // nolint:gosec // standard networkd permissions
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil { if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
return fmt.Errorf("write networkd configuration: %w", err) return fmt.Errorf("write networkd configuration: %w", err)

View File

@@ -14,9 +14,7 @@ import (
"strings" "strings"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
sshclient "github.com/netbirdio/netbird/client/ssh/client" sshclient "github.com/netbirdio/netbird/client/ssh/client"
@@ -36,7 +34,6 @@ const (
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding" enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding" enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth" disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
) )
var ( var (
@@ -50,7 +47,6 @@ var (
knownHostsFile string knownHostsFile string
identityFile string identityFile string
skipCachedToken bool skipCachedToken bool
requestPTY bool
) )
var ( var (
@@ -60,7 +56,6 @@ var (
enableSSHLocalPortForward bool enableSSHLocalPortForward bool
enableSSHRemotePortForward bool enableSSHRemotePortForward bool
disableSSHAuth bool disableSSHAuth bool
sshJWTCacheTTL int
) )
func init() { func init() {
@@ -70,16 +65,13 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server") upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server") upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication") upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port") sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc) sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)") sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)") sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)") sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)") sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport") sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
@@ -105,9 +97,9 @@ SSH Options:
-p, --port int Remote SSH port (default 22) -p, --port int Remote SSH port (default 22)
-u, --user string SSH username -u, --user string SSH username
--login string SSH username (alias for --user) --login string SSH username (alias for --user)
-t, --tty Force pseudo-terminal allocation
--strict-host-key-checking Enable strict host key checking (default: true) --strict-host-key-checking Enable strict host key checking (default: true)
-o, --known-hosts string Path to known_hosts file -o, --known-hosts string Path to known_hosts file
-i, --identity string Path to SSH private key file
Examples: Examples:
netbird ssh peer-hostname netbird ssh peer-hostname
@@ -115,10 +107,8 @@ Examples:
netbird ssh --login root peer-hostname netbird ssh --login root peer-hostname
netbird ssh peer-hostname ls -la netbird ssh peer-hostname ls -la
netbird ssh peer-hostname whoami netbird ssh peer-hostname whoami
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`, netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
DisableFlagParsing: true, DisableFlagParsing: true,
@@ -153,10 +143,10 @@ func sshFn(cmd *cobra.Command, args []string) error {
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT) signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx) sshctx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() { go func() {
if err := runSSH(sshctx, host, cmd); err != nil { if err := runSSH(sshctx, host, cmd); err != nil {
errCh <- err cmd.Printf("Error: %v\n", err)
os.Exit(1)
} }
cancel() cancel()
}() }()
@@ -164,10 +154,6 @@ func sshFn(cmd *cobra.Command, args []string) error {
select { select {
case <-sig: case <-sig:
cancel() cancel()
<-sshctx.Done()
return nil
case err := <-errCh:
return err
case <-sshctx.Done(): case <-sshctx.Done():
} }
@@ -365,7 +351,6 @@ type sshFlags struct {
Port int Port int
Username string Username string
Login string Login string
RequestPTY bool
StrictHostKeyChecking bool StrictHostKeyChecking bool
KnownHostsFile string KnownHostsFile string
IdentityFile string IdentityFile string
@@ -388,24 +373,22 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
flags := &sshFlags{} flags := &sshFlags{}
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port") fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port") fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc) fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc) fs.String("user", "", sshUsernameDesc)
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)") fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking") fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file") fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file") fs.String("known-hosts", "", "Path to known_hosts file")
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file") fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file") fs.String("identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location") fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location") fs.String("config", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level") fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level") fs.String("log-level", defaultLogLevel, "sets Netbird log level")
return fs, flags return fs, flags
} }
@@ -426,10 +409,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
fs, flags := createSSHFlagSet() fs, flags := createSSHFlagSet()
if err := fs.Parse(filteredArgs); err != nil { if err := fs.Parse(filteredArgs); err != nil {
if errors.Is(err, flag.ErrHelp) { return parseHostnameAndCommand(filteredArgs)
return nil
}
return err
} }
remaining := fs.Args() remaining := fs.Args()
@@ -444,7 +424,6 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
username = flags.Login username = flags.Login
} }
requestPTY = flags.RequestPTY
strictHostKeyChecking = flags.StrictHostKeyChecking strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = flags.KnownHostsFile knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile identityFile = flags.IdentityFile
@@ -541,29 +520,10 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
// executeSSHCommand executes a command over SSH. // executeSSHCommand executes a command over SSH.
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error { func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
var err error if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
if requestPTY {
err = c.ExecuteCommandWithPTY(ctx, command)
} else {
err = c.ExecuteCommandWithIO(ctx, command)
}
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil return nil
} }
var exitErr *ssh.ExitError
if errors.As(err, &exitErr) {
os.Exit(exitErr.ExitStatus())
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote command exited without exit status: %v", err)
return nil
}
return fmt.Errorf("execute command: %w", err) return fmt.Errorf("execute command: %w", err)
} }
return nil return nil
@@ -575,13 +535,6 @@ func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil return nil
} }
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote terminal exited without exit status: %v", err)
return nil
}
return fmt.Errorf("open terminal: %w", err) return fmt.Errorf("open terminal: %w", err)
} }
return nil return nil
@@ -765,11 +718,6 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
if err != nil { if err != nil {
return fmt.Errorf("create SSH proxy: %w", err) return fmt.Errorf("create SSH proxy: %w", err)
} }
defer func() {
if err := proxy.Close(); err != nil {
log.Debugf("close SSH proxy: %v", err)
}
}()
if err := proxy.Connect(cmd.Context()); err != nil { if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err) return fmt.Errorf("SSH proxy: %w", err)

View File

@@ -8,7 +8,6 @@ import (
"io" "io"
"os" "os"
"os/user" "os/user"
"strings"
"github.com/pkg/sftp" "github.com/pkg/sftp"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -52,7 +51,7 @@ func sftpMainDirect(cmd *cobra.Command) error {
if windowsDomain != "" { if windowsDomain != "" {
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername) expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
} }
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) { if currentUser.Username != expectedUsername && currentUser.Username != windowsUsername {
cmd.PrintErrf("user switching failed\n") cmd.PrintErrf("user switching failed\n")
os.Exit(sshserver.ExitCodeValidationFail) os.Exit(sshserver.ExitCodeValidationFail)
} }

View File

@@ -667,51 +667,3 @@ func TestSSHCommand_ParameterIsolation(t *testing.T) {
}) })
} }
} }
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
tests := []struct {
name string
args []string
description string
}{
{
name: "invalid long flag before hostname",
args: []string{"--invalid-flag", "hostname"},
description: "Invalid flag should return parse error, not treat flag as hostname",
},
{
name: "invalid short flag before hostname",
args: []string{"-x", "hostname"},
description: "Invalid short flag should return parse error",
},
{
name: "invalid flag with value before hostname",
args: []string{"--invalid-option=value", "hostname"},
description: "Invalid flag with value should return parse error",
},
{
name: "typo in known flag",
args: []string{"--por", "2222", "hostname"},
description: "Typo in flag name should return parse error (not silently ignored)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
// Should return an error for invalid flags
assert.Error(t, err, tt.description)
// Should not have set host to the invalid flag
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
})
}
}

View File

@@ -99,7 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name profName = activeProf.Name
} }
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName) var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
case yamlFlag: case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder) statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default: default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false) statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
} }
if err != nil { if err != nil {

View File

@@ -13,11 +13,6 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
@@ -89,7 +84,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
jobManager := job.NewJobManager(nil, store) peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, nil return nil, nil
@@ -115,18 +110,13 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
ctx := context.Background() accountManager, err := mgmt.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name) err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
connectClient := internal.NewConnectClient(ctx, config, r) connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "") SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil, util.FindFirstLogPath(logFiles)) return connectClient.Run(nil)
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
@@ -286,13 +286,6 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
loginRequest.ProfileName = &activeProf.Name loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse
@@ -362,18 +355,14 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.EnableSSHSFTP = &enableSSHSFTP req.EnableSSHSFTP = &enableSSHSFTP
} }
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward req.EnableSSHLocalPortForward = &enableSSHLocalPortForward
} }
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
} }
if cmd.Flag(disableSSHAuthFlag).Changed { if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth req.DisableSSHAuth = &disableSSHAuth
} }
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err) log.Errorf("parse interface name: %v", err)
@@ -478,10 +467,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.DisableSSHAuth = &disableSSHAuth ic.DisableSSHAuth = &disableSSHAuth
} }
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return nil, err return nil, err
@@ -602,11 +587,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.DisableSSHAuth = &disableSSHAuth loginRequest.DisableSSHAuth = &disableSSHAuth
} }
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed { if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled loginRequest.DisableAutoConnect = &autoConnectDisabled
} }

View File

@@ -173,7 +173,6 @@ func (c *Client) Start(startCtx context.Context) error {
} }
recorder := peer.NewRecorder(c.config.ManagementURL.String()) recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder) client := internal.NewConnectClient(ctx, c.config, recorder)
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
@@ -181,7 +180,7 @@ func (c *Client) Start(startCtx context.Context) error {
run := make(chan struct{}) run := make(chan struct{})
clientErr := make(chan error, 1) clientErr := make(chan error, 1)
go func() { go func() {
if err := client.Run(run, ""); err != nil { if err := client.Run(run); err != nil {
clientErr <- err clientErr <- err
} }
}() }()

View File

@@ -1,14 +1,13 @@
package iptables package iptables
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"slices" "slices"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/uuid" "github.com/google/uuid"
ipset "github.com/lrh3321/ipset-go" "github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -41,13 +40,19 @@ type aclManager struct {
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
return &aclManager{ m := &aclManager{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
wgIface: wgIface, wgIface: wgIface,
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry), optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
}, nil }
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return m, nil
} }
func (m *aclManager) init(stateManager *statemanager.Manager) error { func (m *aclManager) init(stateManager *statemanager.Manager) error {
@@ -93,8 +98,8 @@ func (m *aclManager) AddPeerFiltering(
specs = append(specs, "-j", actionToStr(action)) specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" { if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := m.addToIPSet(ipsetName, ip); err != nil { if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err) return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
} }
// if ruleset already exists it means we already have the firewall rule // if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager. // so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
@@ -108,18 +113,14 @@ func (m *aclManager) AddPeerFiltering(
}}, nil }}, nil
} }
if err := m.flushIPSet(ipsetName); err != nil { if err := ipset.Flush(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) { log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
}
} }
if err := m.createIPSet(ipsetName); err != nil { if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("create ipset: %w", err) return nil, fmt.Errorf("failed to create ipset: %w", err)
} }
if err := m.addToIPSet(ipsetName, ip); err != nil { if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err) return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
} }
ipList := newIpList(ip.String()) ipList := newIpList(ip.String())
@@ -171,16 +172,11 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
} }
shouldDestroyIpset := false
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset // delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok { if _, ok := ipsetList.ips[r.ip]; ok {
ip := net.ParseIP(r.ip) if err := ipset.Del(r.ipsetName, r.ip); err != nil {
if ip == nil { return fmt.Errorf("failed to delete ip from ipset: %w", err)
return fmt.Errorf("parse IP %s", r.ip)
}
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
return fmt.Errorf("delete ip from ipset: %w", err)
} }
delete(ipsetList.ips, r.ip) delete(ipsetList.ips, r.ip)
} }
@@ -194,7 +190,10 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
// we delete last IP from the set, that means we need to delete // we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too // set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName) m.ipsetStore.deleteIpset(r.ipsetName)
shouldDestroyIpset = true
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
} }
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil { if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
@@ -207,16 +206,6 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
} }
} }
if shouldDestroyIpset {
if err := m.destroyIPSet(r.ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy empty ipset: %v", err)
} else {
log.Errorf("destroy empty ipset: %v", err)
}
}
}
m.updateState() m.updateState()
return nil return nil
@@ -275,19 +264,11 @@ func (m *aclManager) cleanChains() error {
} }
for _, ipsetName := range m.ipsetStore.ipsetNames() { for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil { if err := ipset.Flush(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) { log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
} }
if err := m.destroyIPSet(ipsetName); err != nil { if err := ipset.Destroy(ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) { log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
}
} }
m.ipsetStore.deleteIpset(ipsetName) m.ipsetStore.deleteIpset(ipsetName)
} }
@@ -387,8 +368,8 @@ func (m *aclManager) updateState() {
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true matchByIP := true
// don't use IP matching if IP is 0.0.0.0 // don't use IP matching if IP is ip 0.0.0.0
if ip.IsUnspecified() { if ip.String() == "0.0.0.0" {
matchByIP = false matchByIP = false
} }
@@ -435,61 +416,3 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return ipsetName + actionSuffix return ipsetName + actionSuffix
} }
} }
func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add IP to ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
}
if err := ipset.Del(name, entry); err != nil {
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) flushIPSet(name string) error {
return ipset.Flush(name)
}
func (m *aclManager) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
ipset "github.com/lrh3321/ipset-go" "github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
@@ -107,6 +107,10 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
}, },
) )
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return r, nil return r, nil
} }
@@ -228,12 +232,12 @@ func (r *router) findSets(rule []string) []string {
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) error { func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := r.createIPSet(setName); err != nil { if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return fmt.Errorf("create set %s: %w", setName, err) return fmt.Errorf("create set %s: %w", setName, err)
} }
for _, prefix := range sources { for _, prefix := range sources {
if err := r.addPrefixToIPSet(setName, prefix); err != nil { if err := ipset.AddPrefix(setName, prefix); err != nil {
return fmt.Errorf("add element to set %s: %w", setName, err) return fmt.Errorf("add element to set %s: %w", setName, err)
} }
} }
@@ -242,7 +246,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
} }
func (r *router) deleteIpSet(setName string) error { func (r *router) deleteIpSet(setName string) error {
if err := r.destroyIPSet(setName); err != nil { if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err) return fmt.Errorf("destroy set %s: %w", setName, err)
} }
@@ -911,8 +915,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue continue
} }
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil { if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
} }
} }
if merr == nil { if merr == nil {
@@ -989,37 +993,3 @@ func applyPort(flag string, port *firewall.Port) []string {
return []string{flag, strconv.Itoa(int(port.Values[0]))} return []string{flag, strconv.Itoa(int(port.Values[0]))}
} }
func (r *router) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
addr := prefix.Addr()
ip := addr.AsSlice()
entry := &ipset.Entry{
IP: ip,
CIDR: uint8(prefix.Bits()),
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
}
return nil
}
func (r *router) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"runtime" "runtime"
"time" "time"
@@ -11,6 +12,7 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
@@ -18,6 +20,9 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots" "github.com/netbirdio/netbird/util/embeddedroots"
) )
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
// Backoff returns a backoff configuration for gRPC calls // Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff { func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff() b := backoff.NewExponentialBackOff()
@@ -26,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx) return backoff.WithContext(b, ctx)
} }
// waitForConnectionReady blocks until the connection becomes ready or fails.
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
conn.Connect()
state := conn.GetState()
for state != connectivity.Ready && state != connectivity.Shutdown {
if !conn.WaitForStateChange(ctx, state) {
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
}
state = conn.GetState()
}
if state == connectivity.Shutdown {
return ErrConnectionShutdown
}
return nil
}
// CreateConnection creates a gRPC client connection with the appropriate transport options. // CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
@@ -43,22 +68,25 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
})) }))
} }
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) conn, err := grpc.NewClient(
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr, addr,
transportOption, transportOption,
WithCustomDialer(tlsEnabled, component), WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
}), }),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("dial context: %w", err) return nil, fmt.Errorf("new client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := waitForConnectionReady(ctx, conn); err != nil {
_ = conn.Close()
return nil, err
} }
return conn, nil return conn, nil

View File

@@ -1,7 +1,6 @@
package iface package iface
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -10,13 +9,13 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/stdnet"
) )
// keep darwin compatibility // keep darwin compatibility
@@ -41,7 +40,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8" addr := "100.64.0.1/8"
wgPort := 33100 wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -124,7 +123,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
func Test_CreateInterface(t *testing.T) { func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32" wgIP := "10.99.99.1/32"
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -167,7 +166,7 @@ func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32" wgIP := "10.99.99.2/32"
wgPort := 33100 wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -212,7 +211,7 @@ func TestRecreation(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32" wgIP := "10.99.99.2/32"
wgPort := 33100 wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -285,7 +284,7 @@ func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30" wgIP := "10.99.99.5/30"
wgPort := 33100 wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -340,7 +339,7 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) { func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30" wgIP := "10.99.99.9/30"
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -410,7 +409,7 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) { func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30" wgIP := "10.99.99.13/30"
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -472,7 +471,7 @@ func Test_ConnectPeers(t *testing.T) {
peer2wgPort := 33200 peer2wgPort := 33200
keepAlive := 1 * time.Second keepAlive := 1 * time.Second
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -515,7 +514,7 @@ func Test_ConnectPeers(t *testing.T) {
guid = fmt.Sprintf("{%s}", uuid.New().String()) guid = fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid) device.CustomWindowsGUIDString = strings.ToLower(guid)
newNet, err = stdnet.NewNet(context.Background(), nil) newNet, err = stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -1,7 +1,6 @@
package udpmux package udpmux
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -13,9 +12,8 @@ import (
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v3" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
) )
/* /*
@@ -201,7 +199,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
if len(networks) > 0 { if len(networks) > 0 {
if m.params.Net == nil { if m.params.Net == nil {
var err error var err error
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil { if m.params.Net, err = stdnet.NewNet(); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err) m.params.Logger.Errorf("failed to get create network: %v", err)
} }
} }

View File

@@ -128,34 +128,9 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
deviceCode.VerificationURIComplete = deviceCode.VerificationURI deviceCode.VerificationURIComplete = deviceCode.VerificationURI
} }
if d.providerConfig.LoginHint != "" {
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
if deviceCode.VerificationURI != "" {
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
}
}
return deviceCode, err return deviceCode, err
} }
func appendLoginHint(uri, loginHint string) string {
if uri == "" || loginHint == "" {
return uri
}
parsedURL, err := url.Parse(uri)
if err != nil {
log.Debugf("failed to parse verification URI for login_hint: %v", err)
return uri
}
query := parsedURL.Query()
query.Set("login_hint", loginHint)
parsedURL.RawQuery = query.Encode()
return parsedURL.String()
}
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) { func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
form := url.Values{} form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID) form.Add("client_id", d.providerConfig.ClientID)

View File

@@ -66,34 +66,32 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful // and if that also fails, the authentication process is deemed unsuccessful
// //
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config, hint) return authenticateWithDeviceCodeFlow(ctx, config)
} }
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint) pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil { if err != nil {
// fallback to device code flow
log.Debugf("failed to initialize pkce authentication with error: %v\n", err) log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
log.Debug("falling back to device code flow") log.Debug("falling back to device code flow")
return authenticateWithDeviceCodeFlow(ctx, config, hint) return authenticateWithDeviceCodeFlow(ctx, config)
} }
return pkceFlow, nil return pkceFlow, nil
} }
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
} }
pkceFlowInfo.ProviderConfig.LoginHint = hint
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
} }
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil { if err != nil {
switch s, ok := gstatus.FromError(err); { switch s, ok := gstatus.FromError(err); {
@@ -109,7 +107,5 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
} }
} }
deviceFlowInfo.ProviderConfig.LoginHint = hint
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
} }

View File

@@ -109,9 +109,6 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
params = append(params, oauth2.SetAuthURLParam("max_age", "0")) params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
} }
} }
if p.providerConfig.LoginHint != "" {
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...) authURL := p.oAuthConfig.AuthCodeURL(state, params...)
@@ -192,20 +189,17 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token,
if authError := query.Get(queryError); authError != "" { if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc) authErrorDesc := query.Get(queryErrorDesc)
if authErrorDesc != "" { return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
}
return nil, fmt.Errorf("authentication failed: %s", authError)
} }
// Prevent timing attacks on the state // Prevent timing attacks on the state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("authentication failed: Invalid state") return nil, fmt.Errorf("invalid state")
} }
code := query.Get(queryCode) code := query.Get(queryCode)
if code == "" { if code == "" {
return nil, fmt.Errorf("authentication failed: missing code") return nil, fmt.Errorf("missing code")
} }
return p.oAuthConfig.Exchange( return p.oAuthConfig.Exchange(
@@ -234,7 +228,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
} }
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil { if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err) return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
} }
email, err := parseEmailFromIDToken(tokenInfo.IDToken) email, err := parseEmailFromIDToken(tokenInfo.IDToken)

View File

@@ -52,6 +52,7 @@ func NewConnectClient(
ctx context.Context, ctx context.Context,
config *profilemanager.Config, config *profilemanager.Config,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *ConnectClient { ) *ConnectClient {
return &ConnectClient{ return &ConnectClient{
ctx: ctx, ctx: ctx,
@@ -62,8 +63,8 @@ func NewConnectClient(
} }
// Run with main logic. // Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error { func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan, logPath) return c.run(MobileDependency{}, runningChan)
} }
// RunOnAndroid with main logic on mobile system // RunOnAndroid with main logic on mobile system
@@ -82,7 +83,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses, HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener, DnsReadyListener: dnsReadyListener,
} }
return c.run(mobileDependency, nil, "") return c.run(mobileDependency, nil)
} }
func (c *ConnectClient) RunOniOS( func (c *ConnectClient) RunOniOS(
@@ -100,10 +101,10 @@ func (c *ConnectClient) RunOniOS(
DnsManager: dnsManager, DnsManager: dnsManager,
StateFilePath: stateFilePath, StateFilePath: stateFilePath,
} }
return c.run(mobileDependency, nil, "") return c.run(mobileDependency, nil)
} }
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error { func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
rec := c.statusRecorder rec := c.statusRecorder
@@ -246,7 +247,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
relayURLs, token := parseRelayInfo(loginResp) relayURLs, token := parseRelayInfo(loginResp)
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath) engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
@@ -270,7 +271,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks() checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, c.config) c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engineMutex.Unlock() c.engineMutex.Unlock()
@@ -409,7 +410,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
} }
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false nm := false
if config.NetworkMonitor != nil { if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor nm = *config.NetworkMonitor
@@ -444,10 +445,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
LazyConnectionEnabled: config.LazyConnectionEnabled, LazyConnectionEnabled: config.LazyConnectionEnabled,
MTU: selectMTU(config.MTU, peerConfig.Mtu), MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
ProfileConfig: config,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {

View File

@@ -27,10 +27,8 @@ import (
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
) )
const readmeContent = `Netbird debug bundle const readmeContent = `Netbird debug bundle
@@ -46,8 +44,6 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided. ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
config.txt: Anonymized configuration information of the NetBird client. config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
@@ -188,20 +184,6 @@ The ip_rules.txt file contains detailed IP routing rule information:
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing. The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged. For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
DNS Configuration
The debug bundle includes platform-specific DNS configuration files:
resolv.conf (Unix systems):
- Contains DNS resolver configuration from /etc/resolv.conf
- Includes nameserver entries, search domains, and resolver options
- All IP addresses and domain names are anonymized following the same rules as other files
scutil_dns.txt (macOS only):
- Contains detailed DNS configuration from scutil --dns
- Shows DNS configuration for all network interfaces
- Includes search domains, nameservers, and DNS resolver settings
- All IP addresses and domain names are anonymized
` `
const ( const (
@@ -220,9 +202,10 @@ type BundleGenerator struct {
internalConfig *profilemanager.Config internalConfig *profilemanager.Config
statusRecorder *peer.Status statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse syncResponse *mgmProto.SyncResponse
logPath string logFile string
anonymize bool anonymize bool
clientStatus string
includeSystemInfo bool includeSystemInfo bool
logFileCount uint32 logFileCount uint32
@@ -231,6 +214,7 @@ type BundleGenerator struct {
type BundleConfig struct { type BundleConfig struct {
Anonymize bool Anonymize bool
ClientStatus string
IncludeSystemInfo bool IncludeSystemInfo bool
LogFileCount uint32 LogFileCount uint32
} }
@@ -239,7 +223,7 @@ type GeneratorDependencies struct {
InternalConfig *profilemanager.Config InternalConfig *profilemanager.Config
StatusRecorder *peer.Status StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse SyncResponse *mgmProto.SyncResponse
LogPath string LogFile string
} }
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator { func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -255,9 +239,10 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
internalConfig: deps.InternalConfig, internalConfig: deps.InternalConfig,
statusRecorder: deps.StatusRecorder, statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse, syncResponse: deps.SyncResponse,
logPath: deps.LogPath, logFile: deps.LogFile,
anonymize: cfg.Anonymize, anonymize: cfg.Anonymize,
clientStatus: cfg.ClientStatus,
includeSystemInfo: cfg.IncludeSystemInfo, includeSystemInfo: cfg.IncludeSystemInfo,
logFileCount: logFileCount, logFileCount: logFileCount,
} }
@@ -303,6 +288,13 @@ func (g *BundleGenerator) createArchive() error {
return fmt.Errorf("add status: %w", err) return fmt.Errorf("add status: %w", err)
} }
if g.statusRecorder != nil {
status := g.statusRecorder.GetFullStatus()
seedFromStatus(g.anonymizer, &status)
} else {
log.Debugf("no status recorder available for seeding")
}
if err := g.addConfig(); err != nil { if err := g.addConfig(); err != nil {
log.Errorf("failed to add config to debug bundle: %v", err) log.Errorf("failed to add config to debug bundle: %v", err)
} }
@@ -335,7 +327,7 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err) log.Errorf("failed to add wg show output: %v", err)
} }
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) { if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
if err := g.addLogfile(); err != nil { if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err) log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil { if err := g.trySystemdLogFallback(); err != nil {
@@ -365,10 +357,6 @@ func (g *BundleGenerator) addSystemInfo() {
if err := g.addFirewallRules(); err != nil { if err := g.addFirewallRules(); err != nil {
log.Errorf("failed to add firewall rules to debug bundle: %v", err) log.Errorf("failed to add firewall rules to debug bundle: %v", err)
} }
if err := g.addDNSInfo(); err != nil {
log.Errorf("failed to add DNS info to debug bundle: %v", err)
}
} }
func (g *BundleGenerator) addReadme() error { func (g *BundleGenerator) addReadme() error {
@@ -380,26 +368,11 @@ func (g *BundleGenerator) addReadme() error {
} }
func (g *BundleGenerator) addStatus() error { func (g *BundleGenerator) addStatus() error {
if g.statusRecorder != nil { if status := g.clientStatus; status != "" {
pm := profilemanager.NewProfileManager() statusReader := strings.NewReader(status)
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
fullStatus := g.statusRecorder.GetFullStatus()
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
statusOutput := nbstatus.ParseToFullDetailSummary(overview)
statusReader := strings.NewReader(statusOutput)
if err := g.addFileToZip(statusReader, "status.txt"); err != nil { if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
return fmt.Errorf("add status file to zip: %w", err) return fmt.Errorf("add status file to zip: %w", err)
} }
seedFromStatus(g.anonymizer, &fullStatus)
} else {
log.Debugf("no status recorder available for seeding")
} }
return nil return nil
} }
@@ -669,14 +642,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
} }
func (g *BundleGenerator) addLogfile() error { func (g *BundleGenerator) addLogfile() error {
if g.logPath == "" { if g.logFile == "" {
log.Debugf("skipping empty log file in debug bundle") log.Debugf("skipping empty log file in debug bundle")
return nil return nil
} }
logDir := filepath.Dir(g.logPath) logDir := filepath.Dir(g.logFile)
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil { if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
return fmt.Errorf("add client log file to zip: %w", err) return fmt.Errorf("add client log file to zip: %w", err)
} }

View File

@@ -1,53 +0,0 @@
//go:build darwin && !ios
package debug
import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// addDNSInfo collects and adds DNS configuration information to the archive
func (g *BundleGenerator) addDNSInfo() error {
if err := g.addResolvConf(); err != nil {
log.Errorf("failed to add resolv.conf: %v", err)
}
if err := g.addScutilDNS(); err != nil {
log.Errorf("failed to add scutil DNS output: %v", err)
}
return nil
}
func (g *BundleGenerator) addScutilDNS() error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "scutil", "--dns")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("execute scutil --dns: %w", err)
}
if len(bytes.TrimSpace(output)) == 0 {
return fmt.Errorf("no scutil DNS output")
}
content := string(output)
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
return fmt.Errorf("add scutil DNS output to zip: %w", err)
}
return nil
}

View File

@@ -5,7 +5,3 @@ package debug
func (g *BundleGenerator) addRoutes() error { func (g *BundleGenerator) addRoutes() error {
return nil return nil
} }
func (g *BundleGenerator) addDNSInfo() error {
return nil
}

View File

@@ -1,16 +0,0 @@
//go:build unix && !darwin && !android
package debug
import (
log "github.com/sirupsen/logrus"
)
// addDNSInfo collects and adds DNS configuration information to the archive
func (g *BundleGenerator) addDNSInfo() error {
if err := g.addResolvConf(); err != nil {
log.Errorf("failed to add resolv.conf: %v", err)
}
return nil
}

View File

@@ -1,7 +0,0 @@
//go:build !unix
package debug
func (g *BundleGenerator) addDNSInfo() error {
return nil
}

View File

@@ -1,29 +0,0 @@
//go:build unix && !android
package debug
import (
"fmt"
"os"
"strings"
)
const resolvConfPath = "/etc/resolv.conf"
func (g *BundleGenerator) addResolvConf() error {
data, err := os.ReadFile(resolvConfPath)
if err != nil {
return fmt.Errorf("read %s: %w", resolvConfPath, err)
}
content := string(data)
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
return fmt.Errorf("add resolv.conf to zip: %w", err)
}
return nil
}

View File

@@ -1,101 +0,0 @@
package debug
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/netbirdio/netbird/upload-server/types"
)
const maxBundleUploadSize = 50 * 1024 * 1024
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}

View File

@@ -38,8 +38,6 @@ type DeviceAuthProviderConfig struct {
Scope string Scope string
// UseIDToken indicates if the id token should be used for authentication // UseIDToken indicates if the id token should be used for authentication
UseIDToken bool UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
} }
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it // GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it

View File

@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey() privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet(nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true") t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"}) newNet, err := stdnet.NewNet([]string{"utun2301"})
if err != nil { if err != nil {
t.Errorf("create stdnet: %v", err) t.Errorf("create stdnet: %v", err)
return return
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true") t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"}) newNet, err := stdnet.NewNet([]string{"utun2301"})
if err != nil { if err != nil {
t.Fatalf("create stdnet: %v", err) t.Fatalf("create stdnet: %v", err)
return nil, err return nil, err

View File

@@ -31,7 +31,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
@@ -49,7 +48,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto" cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config" sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
@@ -134,11 +132,6 @@ type EngineConfig struct {
LazyConnectionEnabled bool LazyConnectionEnabled bool
MTU uint16 MTU uint16
// for debug bundle generation
ProfileConfig *profilemanager.Config
LogPath string
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -202,8 +195,7 @@ type Engine struct {
stateManager *statemanager.Manager stateManager *statemanager.Manager
srWatcher *guard.SRWatcher srWatcher *guard.SRWatcher
// Sync response persistence (protected by syncRespMux) // Sync response persistence
syncRespMux sync.RWMutex
persistSyncResponse bool persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup connSemaphore *semaphoregroup.SemaphoreGroup
@@ -216,9 +208,6 @@ type Engine struct {
shutdownWg sync.WaitGroup shutdownWg sync.WaitGroup
probeStunTurn *relay.StunTurnProbe probeStunTurn *relay.StunTurnProbe
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -232,7 +221,17 @@ type localIpUpdater interface {
} }
// NewEngine creates a new Connection Engine with probes attached // NewEngine creates a new Connection Engine with probes attached
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, c *profilemanager.Config) *Engine { func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
engine := &Engine{ engine := &Engine{
clientCtx: clientCtx, clientCtx: clientCtx,
clientCancel: clientCancel, clientCancel: clientCancel,
@@ -251,7 +250,6 @@ func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signa
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
} }
sm := profilemanager.NewServiceManager("") sm := profilemanager.NewServiceManager("")
@@ -332,8 +330,6 @@ func (e *Engine) Stop() error {
e.cancel() e.cancel()
} }
e.jobExecutorWG.Wait() // block until job goroutines finish
e.close() e.close()
// stop flow manager after wg interface is gone // stop flow manager after wg interface is gone
@@ -520,7 +516,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.receiveSignalEvents() e.receiveSignalEvents()
e.receiveManagementEvents() e.receiveManagementEvents()
e.receiveJobEvents()
// starting network monitor at the very last to avoid disruptions // starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor() e.startNetworkMonitor()
@@ -798,18 +793,9 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil return nil
} }
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
e.syncRespMux.RUnlock()
// Store sync response if persistence is enabled // Store sync response if persistence is enabled
if enabled { if e.persistSyncResponse {
e.syncRespMux.Lock()
e.latestSyncResponse = update e.latestSyncResponse = update
e.syncRespMux.Unlock()
log.Debugf("sync response persisted with serial %d", nm.GetSerial()) log.Debugf("sync response persisted with serial %d", nm.GetSerial())
} }
@@ -939,77 +925,6 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil return nil
} }
func (e *Engine) receiveJobEvents() {
e.jobExecutorWG.Add(1)
go func() {
defer e.jobExecutorWG.Done()
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
resp := mgmProto.JobResponse{
ID: msg.ID,
Status: mgmProto.JobStatus_failed,
}
switch params := msg.WorkloadParameters.(type) {
case *mgmProto.JobRequest_Bundle:
bundleResult, err := e.handleBundle(params.Bundle)
if err != nil {
log.Errorf("handling bundle: %v", err)
resp.Reason = []byte(err.Error())
return &resp
}
resp.Status = mgmProto.JobStatus_succeeded
resp.WorkloadResults = bundleResult
return &resp
default:
resp.Reason = []byte(jobexec.ErrJobNotImplemented.Error())
return &resp
}
})
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return
}
log.Info("stopped receiving jobs from Management Service")
}()
log.Info("connecting to Management Service jobs stream")
}
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
log.Infof("handle remote debug bundle request: %s", params.String())
syncResponse, err := e.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
bundleDeps := debug.GeneratorDependencies{
InternalConfig: e.config.ProfileConfig,
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
}
bundleJobParams := debug.BundleConfig{
Anonymize: params.Anonymize,
IncludeSystemInfo: true,
LogFileCount: uint32(params.LogFileCount),
}
waitFor := time.Duration(params.BundleForTime) * time.Minute
uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, waitFor, e.config.ProfileConfig.ManagementURL.String())
if err != nil {
return nil, err
}
response := &mgmProto.JobResponse_Bundle{
Bundle: &mgmProto.BundleResult{
UploadKey: uploadKey,
},
}
return response, nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it. // E.g. when a new peer has been registered and we are allowed to connect to it.
@@ -1277,7 +1192,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) //nolint forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
if forwarderPort == 0 { if forwarderPort == 0 {
forwarderPort = nbdns.ForwarderClientPort forwarderPort = nbdns.ForwarderClientPort
} }
@@ -1870,8 +1785,8 @@ func (e *Engine) stopDNSServer() {
// SetSyncResponsePersistence enables or disables sync response persistence // SetSyncResponsePersistence enables or disables sync response persistence
func (e *Engine) SetSyncResponsePersistence(enabled bool) { func (e *Engine) SetSyncResponsePersistence(enabled bool) {
e.syncRespMux.Lock() e.syncMsgMux.Lock()
defer e.syncRespMux.Unlock() defer e.syncMsgMux.Unlock()
if enabled == e.persistSyncResponse { if enabled == e.persistSyncResponse {
return return
@@ -1886,22 +1801,20 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
// GetLatestSyncResponse returns the stored sync response if persistence is enabled // GetLatestSyncResponse returns the stored sync response if persistence is enabled
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) { func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
e.syncRespMux.RLock() e.syncMsgMux.Lock()
enabled := e.persistSyncResponse defer e.syncMsgMux.Unlock()
latest := e.latestSyncResponse
e.syncRespMux.RUnlock()
if !enabled { if !e.persistSyncResponse {
return nil, errors.New("sync response persistence is disabled") return nil, errors.New("sync response persistence is disabled")
} }
if latest == nil { if e.latestSyncResponse == nil {
//nolint:nilnil //nolint:nilnil
return nil, nil return nil, nil
} }
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest)) log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse) sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
if !ok { if !ok {
return nil, fmt.Errorf("failed to clone sync response") return nil, fmt.Errorf("failed to clone sync response")
} }

View File

@@ -19,7 +19,6 @@ import (
type sshServer interface { type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error Start(ctx context.Context, addr netip.AddrPort) error
Stop() error Stop() error
GetStatus() (bool, []sshserver.SessionInfo)
} }
func (e *Engine) setupSSHPortRedirection() error { func (e *Engine) setupSSHPortRedirection() error {
@@ -235,17 +234,7 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
server.SetNetstackNet(netstackNet) server.SetNetstackNet(netstackNet)
}
e.configureSSHServer(server)
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err)
}
e.sshServer = server
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface { if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16) RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok { }); ok {
@@ -254,10 +243,17 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
} }
} }
e.configureSSHServer(server)
e.sshServer = server
if err := e.setupSSHPortRedirection(); err != nil { if err := e.setupSSHPortRedirection(); err != nil {
log.Warnf("failed to setup SSH port redirection: %v", err) log.Warnf("failed to setup SSH port redirection: %v", err)
} }
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err)
}
return nil return nil
} }
@@ -340,16 +336,3 @@ func (e *Engine) stopSSHServer() error {
} }
return nil return nil
} }
// GetSSHServerStatus returns the SSH server status and active sessions
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
e.syncMsgMux.Lock()
sshServer := e.sshServer
e.syncMsgMux.Unlock()
if sshServer == nil {
return false, nil
}
return sshServer.GetStatus()
}

View File

@@ -7,5 +7,5 @@ import (
) )
func (e *Engine) newStdNet() (*stdnet.Net, error) { func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList) return stdnet.NewNet(e.config.IFaceBlackList)
} }

View File

@@ -3,5 +3,5 @@ package internal
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"
func (e *Engine) newStdNet() (*stdnet.Net, error) { func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList) return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -24,15 +25,8 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
@@ -249,7 +243,7 @@ func TestEngine_SSH(t *testing.T) {
}, },
MobileDependency{}, MobileDependency{},
peer.NewRecorder("https://mgm"), peer.NewRecorder("https://mgm"),
nil, nil, nil,
) )
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
@@ -411,13 +405,21 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel() defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ engine := NewEngine(
WgIfaceName: "utun102", ctx, cancel,
WgAddr: "100.64.0.1/24", &signal.MockClient{},
WgPrivateKey: key, &mgmt.MockClient{},
WgPort: 33100, relayMgr,
MTU: iface.DefaultMTU, &EngineConfig{
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil)
wgIface := &MockWGIface{ wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" }, NameFunc: func() string { return "utun102" },
@@ -636,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU, MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
@@ -801,9 +803,9 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU, MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -1003,10 +1005,10 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU, MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -1529,7 +1531,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
} }
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx e.ctx = ctx
return e, err return e, err
} }
@@ -1588,7 +1590,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
jobManager := job.NewJobManager(nil, store) peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@@ -1616,16 +1618,13 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics) accountManager, err := server.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) { func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
log.Debugf("Gathering ICE candidates") log.Debugf("Gathering ICE candidates")
agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
if err != nil { if err != nil {
return false, fmt.Errorf("create ICE agent: %w", err) return false, fmt.Errorf("create ICE agent: %w", err)
} }

View File

@@ -1,7 +1,6 @@
package ice package ice
import ( import (
"context"
"sync" "sync"
"time" "time"
@@ -23,8 +22,6 @@ const (
iceFailedTimeoutDefault = 6 * time.Second iceFailedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second iceRelayAcceptanceMinWaitDefault = 2 * time.Second
// iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete
iceAgentCloseTimeout = 3 * time.Second
) )
type ThreadSafeAgent struct { type ThreadSafeAgent struct {
@@ -35,28 +32,18 @@ type ThreadSafeAgent struct {
func (a *ThreadSafeAgent) Close() error { func (a *ThreadSafeAgent) Close() error {
var err error var err error
a.once.Do(func() { a.once.Do(func() {
done := make(chan error, 1) err = a.Agent.Close()
go func() {
done <- a.Agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
}) })
return err return err
} }
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive() iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout() iceDisconnectedTimeout := iceDisconnectedTimeout()
iceFailedTimeout := iceFailedTimeout() iceFailedTimeout := iceFailedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList) transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
if err != nil { if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err) log.Errorf("failed to create pion's stdnet: %s", err)
} }

View File

@@ -3,11 +3,9 @@
package ice package ice
import ( import (
"context"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
) )
func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNet(ctx, ifaceBlacklist) return stdnet.NewNet(ifaceBlacklist)
} }

View File

@@ -1,11 +1,7 @@
package ice package ice
import ( import "github.com/netbirdio/netbird/client/internal/stdnet"
"context"
"github.com/netbirdio/netbird/client/internal/stdnet" func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
) return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist)
} }

View File

@@ -209,7 +209,7 @@ func (w *WorkerICE) Close() {
} }
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil { if err != nil {
return nil, fmt.Errorf("create agent: %w", err) return nil, fmt.Errorf("create agent: %w", err)
} }
@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
if isController(w.config) { if isController(w.config) {
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else { } else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} }

View File

@@ -44,8 +44,6 @@ type PKCEAuthProviderConfig struct {
DisablePromptLogin bool DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior // LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
} }
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it

View File

@@ -55,7 +55,6 @@ type ConfigInput struct {
EnableSSHLocalPortForwarding *bool EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool DisableSSHAuth *bool
SSHJWTCacheTTL *int
NATExternalIPs []string NATExternalIPs []string
CustomDNSAddress []byte CustomDNSAddress []byte
RosenpassEnabled *bool RosenpassEnabled *bool
@@ -105,7 +104,6 @@ type Config struct {
EnableSSHLocalPortForwarding *bool EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool DisableSSHAuth *bool
SSHJWTCacheTTL *int
DisableClientRoutes bool DisableClientRoutes bool
DisableServerRoutes bool DisableServerRoutes bool
@@ -438,12 +436,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval { if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)", log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String()) input.DNSRouteInterval.String(), config.DNSRouteInterval.String())

View File

@@ -132,21 +132,3 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
return nil return nil
} }
// GetLoginHint retrieves the email from the active profile to use as login_hint.
func GetLoginHint() string {
pm := NewProfileManager()
activeProf, err := pm.GetActiveProfile()
if err != nil {
log.Debugf("failed to get active profile for login hint: %v", err)
return ""
}
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
return ""
}
return profileState.Email
}

View File

@@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri
} }
}() }()
net, err := stdnet.NewNet(ctx, nil) net, err := stdnet.NewNet(nil)
if err != nil { if err != nil {
probeErr = fmt.Errorf("new net: %w", err) probeErr = fmt.Errorf("new net: %w", err)
return return
@@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
} }
}() }()
net, err := stdnet.NewNet(ctx, nil) net, err := stdnet.NewNet(nil)
if err != nil { if err != nil {
probeErr = fmt.Errorf("new net: %w", err) probeErr = fmt.Errorf("new net: %w", err)
return return

View File

@@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client" "github.com/netbirdio/netbird/client/internal/routemanager/client"
@@ -38,7 +39,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/client/net"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"

View File

@@ -6,7 +6,7 @@ import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/pion/transport/v3/stdnet"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
peerPrivateKey, _ := wgtypes.GeneratePrivateKey() peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -15,7 +15,7 @@ import (
"syscall" "syscall"
"testing" "testing"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/pion/transport/v3/stdnet"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
peerPrivateKey, err := wgtypes.GeneratePrivateKey() peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err) require.NoError(t, err)
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
require.NoError(t, err) require.NoError(t, err)
opts := iface.WGIFaceOpts{ opts := iface.WGIFaceOpts{

View File

@@ -4,28 +4,17 @@
package stdnet package stdnet
import ( import (
"context"
"errors"
"fmt" "fmt"
"net"
"net/netip"
"slices" "slices"
"strconv"
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/iface/netstack"
) )
const ( const updateInterval = 30 * time.Second
updateInterval = 30 * time.Second
dnsResolveTimeout = 30 * time.Second
)
var errNoSuitableAddress = errors.New("no suitable address found")
// Net is an implementation of the net.Net interface // Net is an implementation of the net.Net interface
// based on functions of the standard net package. // based on functions of the standard net package.
@@ -39,19 +28,12 @@ type Net struct {
// mu is shared between interfaces and lastUpdate // mu is shared between interfaces and lastUpdate
mu sync.Mutex mu sync.Mutex
// ctx is the context for network operations that supports cancellation
ctx context.Context
} }
// NewNetWithDiscover creates a new StdNet instance. // NewNetWithDiscover creates a new StdNet instance.
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
if ctx == nil {
ctx = context.Background()
}
n := &Net{ n := &Net{
interfaceFilter: InterfaceFilter(disallowList), interfaceFilter: InterfaceFilter(disallowList),
ctx: ctx,
} }
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client // current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
// so in android cli use pionDiscover // so in android cli use pionDiscover
@@ -64,64 +46,14 @@ func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover
} }
// NewNet creates a new StdNet instance. // NewNet creates a new StdNet instance.
func NewNet(ctx context.Context, disallowList []string) (*Net, error) { func NewNet(disallowList []string) (*Net, error) {
if ctx == nil {
ctx = context.Background()
}
n := &Net{ n := &Net{
iFaceDiscover: pionDiscover{}, iFaceDiscover: pionDiscover{},
interfaceFilter: InterfaceFilter(disallowList), interfaceFilter: InterfaceFilter(disallowList),
ctx: ctx,
} }
return n, n.UpdateInterfaces() return n, n.UpdateInterfaces()
} }
// resolveAddr performs DNS resolution with context support and timeout.
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
host, portStr, err := net.SplitHostPort(address)
if err != nil {
return netip.AddrPort{}, err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
}
if port < 0 || port > 65535 {
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
}
ipNet := "ip"
switch network {
case "tcp4", "udp4":
ipNet = "ip4"
case "tcp6", "udp6":
ipNet = "ip6"
}
if host == "" {
addr := netip.IPv4Unspecified()
if ipNet == "ip6" {
addr = netip.IPv6Unspecified()
}
return netip.AddrPortFrom(addr, uint16(port)), nil
}
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
defer cancel()
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
if err != nil {
return netip.AddrPort{}, err
}
if len(addrs) == 0 {
return netip.AddrPort{}, errNoSuitableAddress
}
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
}
// UpdateInterfaces updates the internal list of network interfaces // UpdateInterfaces updates the internal list of network interfaces
// and associated addresses filtering them by name. // and associated addresses filtering them by name.
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one // The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
@@ -205,39 +137,3 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
} }
return result return result
} }
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
switch network {
case "udp", "udp4", "udp6":
case "":
network = "udp"
default:
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
}
addrPort, err := n.resolveAddr(network, address)
if err != nil {
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
}
return net.UDPAddrFromAddrPort(addrPort), nil
}
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
switch network {
case "tcp", "tcp4", "tcp6":
case "":
network = "tcp"
default:
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
}
addrPort, err := n.resolveAddr(network, address)
if err != nil {
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
}
return net.TCPAddrFromAddrPort(addrPort), nil
}

File diff suppressed because one or more lines are too long

View File

@@ -1,299 +0,0 @@
package templates
import (
"html/template"
"os"
"path/filepath"
"testing"
)
func TestPKCEAuthMsgTemplate(t *testing.T) {
tests := []struct {
name string
data map[string]string
outputFile string
expectedTitle string
expectedInContent []string
notExpectedInContent []string
}{
{
name: "error_state",
data: map[string]string{
"Error": "authentication failed: invalid state",
},
outputFile: "pkce-auth-error.html",
expectedTitle: "Login Failed",
expectedInContent: []string{
"authentication failed: invalid state",
"Login Failed",
},
notExpectedInContent: []string{
"Login Successful",
"Your device is now registered and logged in to NetBird",
},
},
{
name: "success_state",
data: map[string]string{
// No error field means success
},
outputFile: "pkce-auth-success.html",
expectedTitle: "Login Successful",
expectedInContent: []string{
"Login Successful",
"Your device is now registered and logged in to NetBird. You can now close this window.",
},
notExpectedInContent: []string{
"Login Failed",
},
},
{
name: "error_state_timeout",
data: map[string]string{
"Error": "authentication timeout: request expired after 5 minutes",
},
outputFile: "pkce-auth-timeout.html",
expectedTitle: "Login Failed",
expectedInContent: []string{
"authentication timeout: request expired after 5 minutes",
"Login Failed",
},
notExpectedInContent: []string{
"Login Successful",
"Your device is now registered and logged in to NetBird",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Parse the template
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
// Create temp directory for this test
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, tt.outputFile)
// Create output file
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
// Execute the template
if err := tmpl.Execute(file, tt.data); err != nil {
file.Close()
t.Fatalf("Failed to execute template: %v", err)
}
file.Close()
t.Logf("Generated test output: %s", outputPath)
// Read the generated file
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
contentStr := string(content)
// Verify file has content
if len(contentStr) == 0 {
t.Error("Output file is empty")
}
// Verify basic HTML structure
basicElements := []string{
"<!DOCTYPE html>",
"<html",
"<head>",
"<body>",
"NetBird",
}
for _, elem := range basicElements {
if !contains(contentStr, elem) {
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
}
}
// Verify expected title
if !contains(contentStr, tt.expectedTitle) {
t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle)
}
// Verify expected content is present
for _, expected := range tt.expectedInContent {
if !contains(contentStr, expected) {
t.Errorf("Expected HTML to contain '%s', but it was not found", expected)
}
}
// Verify unexpected content is not present
for _, notExpected := range tt.notExpectedInContent {
if contains(contentStr, notExpected) {
t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected)
}
}
})
}
}
func TestPKCEAuthMsgTemplateValidation(t *testing.T) {
// Test that the template can be parsed without errors
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
if err != nil {
t.Fatalf("Template parsing failed: %v", err)
}
// Test with empty data
t.Run("empty_data", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "empty-data.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
if err := tmpl.Execute(file, nil); err != nil {
t.Errorf("Template execution with nil data failed: %v", err)
}
})
// Test with error data
t.Run("with_error", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "with-error.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
data := map[string]string{
"Error": "test error message",
}
if err := tmpl.Execute(file, data); err != nil {
t.Errorf("Template execution with error data failed: %v", err)
}
})
}
func TestPKCEAuthMsgTemplateContent(t *testing.T) {
// Test that the template contains expected elements
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
if err != nil {
t.Fatalf("Template parsing failed: %v", err)
}
t.Run("success_content", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "success.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
data := map[string]string{}
if err := tmpl.Execute(file, data); err != nil {
t.Fatalf("Template execution failed: %v", err)
}
// Read the file and verify it contains expected content
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
// Check for success indicators
contentStr := string(content)
if len(contentStr) == 0 {
t.Error("Generated HTML is empty")
}
// Basic HTML structure checks
requiredElements := []string{
"<!DOCTYPE html>",
"<html",
"<head>",
"<body>",
"Login Successful",
"NetBird",
}
for _, elem := range requiredElements {
if !contains(contentStr, elem) {
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
}
}
})
t.Run("error_content", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "error.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
errorMsg := "test error message"
data := map[string]string{
"Error": errorMsg,
}
if err := tmpl.Execute(file, data); err != nil {
t.Fatalf("Template execution failed: %v", err)
}
// Read the file and verify it contains expected content
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
// Check for error indicators
contentStr := string(content)
if len(contentStr) == 0 {
t.Error("Generated HTML is empty")
}
// Basic HTML structure checks
requiredElements := []string{
"<!DOCTYPE html>",
"<html",
"<head>",
"<body>",
"Login Failed",
errorMsg,
}
for _, elem := range requiredElements {
if !contains(contentStr, elem) {
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
}
}
})
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
(len(s) > 0 && len(substr) > 0 && containsHelper(s, substr)))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string {
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "") oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
if err != nil { if err != nil {
return err.Error() return err.Error()
} }

View File

@@ -1,66 +0,0 @@
package jobexec
import (
"context"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/upload-server/types"
)
const (
MaxBundleWaitTime = 60 * time.Minute // maximum wait time for bundle generation (1 hour)
)
var (
ErrJobNotImplemented = errors.New("job not implemented")
)
type Executor struct {
}
func NewExecutor() *Executor {
return &Executor{}
}
func (e *Executor) BundleJob(ctx context.Context, debugBundleDependencies debug.GeneratorDependencies, params debug.BundleConfig, waitForDuration time.Duration, mgmURL string) (string, error) {
if waitForDuration > MaxBundleWaitTime {
log.Warnf("bundle wait time %v exceeds maximum %v, capping to maximum", waitForDuration, MaxBundleWaitTime)
waitForDuration = MaxBundleWaitTime
}
if waitForDuration > 0 {
waitFor(ctx, waitForDuration)
}
log.Infof("execute debug bundle generation")
bundleGenerator := debug.NewBundleGenerator(debugBundleDependencies, params)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
key, err := debug.UploadDebugBundle(ctx, types.DefaultBundleURL, mgmURL, path)
if err != nil {
log.Errorf("failed to upload debug bundle: %v", err)
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle has been generated well")
return key, nil
}
func waitFor(ctx context.Context, duration time.Duration) {
log.Infof("wait for %v minutes before executing debug bundle", duration.Minutes())
select {
case <-time.After(duration):
case <-ctx.Done():
log.Infof("wait cancelled: %v", ctx.Err())
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -90,7 +90,7 @@ service DaemonService {
// RequestJWTAuth initiates JWT authentication flow for SSH // RequestJWTAuth initiates JWT authentication flow for SSH
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {} rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
// WaitJWTToken waits for JWT authentication completion // WaitJWTToken waits for JWT authentication completion
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {} rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
} }
@@ -168,15 +168,11 @@ message LoginRequest {
optional int64 mtu = 32; optional int64 mtu = 32;
// hint is used to pre-fill the email/username field during SSO authentication optional bool enableSSHRoot = 33;
optional string hint = 33; optional bool enableSSHSFTP = 34;
optional bool enableSSHLocalPortForwarding = 35;
optional bool enableSSHRoot = 34; optional bool enableSSHRemotePortForwarding = 36;
optional bool enableSSHSFTP = 35; optional bool disableSSHAuth = 37;
optional bool enableSSHLocalPortForwarding = 36;
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
} }
message LoginResponse { message LoginResponse {
@@ -206,7 +202,7 @@ message StatusRequest{
bool getFullPeerStatus = 1; bool getFullPeerStatus = 1;
bool shouldRunProbes = 2; bool shouldRunProbes = 2;
// the UI do not using this yet, but CLIs could use it to wait until the status is ready // the UI do not using this yet, but CLIs could use it to wait until the status is ready
optional bool waitForReady = 3; optional bool waitForReady = 3;
} }
message StatusResponse{ message StatusResponse{
@@ -281,8 +277,6 @@ message GetConfigResponse {
bool enableSSHRemotePortForwarding = 23; bool enableSSHRemotePortForwarding = 23;
bool disableSSHAuth = 25; bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
} }
// PeerState contains the latest state of a peer // PeerState contains the latest state of a peer
@@ -346,20 +340,6 @@ message NSGroupState {
string error = 4; string error = 4;
} }
// SSHSessionInfo contains information about an active SSH session
message SSHSessionInfo {
string username = 1;
string remoteAddress = 2;
string command = 3;
string jwtUsername = 4;
}
// SSHServerState contains the latest state of the SSH server
message SSHServerState {
bool enabled = 1;
repeated SSHSessionInfo sessions = 2;
}
// FullStatus contains the full state held by the Status instance // FullStatus contains the full state held by the Status instance
message FullStatus { message FullStatus {
ManagementState managementState = 1; ManagementState managementState = 1;
@@ -373,7 +353,6 @@ message FullStatus {
repeated SystemEvent events = 7; repeated SystemEvent events = 7;
bool lazyConnectionEnabled = 9; bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
} }
// Networks // Networks
@@ -434,6 +413,7 @@ message ForwardingRulesResponse {
// DebugBundler // DebugBundler
message DebugBundleRequest { message DebugBundleRequest {
bool anonymize = 1; bool anonymize = 1;
string status = 2;
bool systemInfo = 3; bool systemInfo = 3;
string uploadURL = 4; string uploadURL = 4;
uint32 logFileCount = 5; uint32 logFileCount = 5;
@@ -639,10 +619,9 @@ message SetConfigRequest {
optional bool enableSSHRoot = 29; optional bool enableSSHRoot = 29;
optional bool enableSSHSFTP = 30; optional bool enableSSHSFTP = 30;
optional bool enableSSHLocalPortForwarding = 31; optional bool enableSSHLocalPortForward = 31;
optional bool enableSSHRemotePortForwarding = 32; optional bool enableSSHRemotePortForward = 32;
optional bool disableSSHAuth = 33; optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
} }
message SetConfigResponse{} message SetConfigResponse{}
@@ -715,8 +694,6 @@ message GetPeerSSHHostKeyResponse {
// RequestJWTAuthRequest for initiating JWT authentication flow // RequestJWTAuthRequest for initiating JWT authentication flow
message RequestJWTAuthRequest { message RequestJWTAuthRequest {
// hint for OIDC login_hint parameter (typically email address)
optional string hint = 1;
} }
// RequestJWTAuthResponse contains authentication flow information // RequestJWTAuthResponse contains authentication flow information

View File

@@ -4,16 +4,24 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"os"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
) )
const maxBundleUploadSize = 50 * 1024 * 1024
// DebugBundle creates a debug bundle and returns the location. // DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock() s.mutex.Lock()
@@ -29,10 +37,11 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
InternalConfig: s.config, InternalConfig: s.config,
StatusRecorder: s.statusRecorder, StatusRecorder: s.statusRecorder,
SyncResponse: syncResponse, SyncResponse: syncResponse,
LogPath: s.logFile, LogFile: s.logFile,
}, },
debug.BundleConfig{ debug.BundleConfig{
Anonymize: req.GetAnonymize(), Anonymize: req.GetAnonymize(),
ClientStatus: req.GetStatus(),
IncludeSystemInfo: req.GetSystemInfo(), IncludeSystemInfo: req.GetSystemInfo(),
LogFileCount: req.GetLogFileCount(), LogFileCount: req.GetLogFileCount(),
}, },
@@ -46,7 +55,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
if req.GetUploadURL() == "" { if req.GetUploadURL() == "" {
return &proto.DebugBundleResponse{Path: path}, nil return &proto.DebugBundleResponse{Path: path}, nil
} }
key, err := debug.UploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path) key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
if err != nil { if err != nil {
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err) log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
@@ -57,6 +66,92 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
} }
func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}
// GetLogLevel gets the current logging level for the server. // GetLogLevel gets the current logging level for the server.
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) { func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
s.mutex.Lock() s.mutex.Lock()

View File

@@ -1,4 +1,4 @@
package debug package server
import ( import (
"context" "context"
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
fileContent := []byte("test file content") fileContent := []byte("test file content")
err := os.WriteFile(file, fileContent, 0640) err := os.WriteFile(file, fileContent, 0640)
require.NoError(t, err) require.NoError(t, err)
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file) key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
require.NoError(t, err) require.NoError(t, err)
id := getURLHash(testURL) id := getURLHash(testURL)
require.Contains(t, key, id+"/") require.Contains(t, key, id+"/")

View File

@@ -37,18 +37,13 @@ func (c *jwtCache) store(token string, maxAge time.Duration) {
c.expiresAt = time.Now().Add(maxAge) c.expiresAt = time.Now().Add(maxAge)
var timer *time.Timer c.timer = time.AfterFunc(maxAge, func() {
timer = time.AfterFunc(maxAge, func() {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.timer != timer {
return
}
c.cleanup() c.cleanup()
c.timer = nil c.timer = nil
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge) log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
}) })
c.timer = timer
} }
func (c *jwtCache) get() (string, bool) { func (c *jwtCache) get() (string, bool) {
@@ -75,5 +70,4 @@ func (c *jwtCache) cleanup() {
if c.enclave != nil { if c.enclave != nil {
c.enclave = nil c.enclave = nil
} }
c.expiresAt = time.Time{}
} }

View File

@@ -13,11 +13,15 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" "golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -28,7 +32,6 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -43,8 +46,8 @@ const (
defaultMaxRetryTime = 14 * 24 * time.Hour defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7 defaultRetryMultiplier = 1.7
// JWT token cache TTL for the client daemon (disabled by default) // JWT token cache TTL for the client daemon
defaultJWTCacheTTL = 0 defaultJWTCacheTTL = 5 * time.Minute
errRestoreResidualState = "failed to restore residual state: %v" errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
@@ -378,15 +381,11 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.BlockInbound = msg.BlockInbound config.BlockInbound = msg.BlockInbound
config.EnableSSHRoot = msg.EnableSSHRoot config.EnableSSHRoot = msg.EnableSSHRoot
config.EnableSSHSFTP = msg.EnableSSHSFTP config.EnableSSHSFTP = msg.EnableSSHSFTP
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward
if msg.DisableSSHAuth != nil { if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth config.DisableSSHAuth = msg.DisableSSHAuth
} }
if msg.SshJWTCacheTTL != nil {
ttl := int(*msg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
}
if msg.Mtu != nil { if msg.Mtu != nil {
mtu := uint16(*msg.Mtu) mtu := uint16(*msg.Mtu)
@@ -497,11 +496,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting) state.Set(internal.StatusConnecting)
if msg.SetupKey == "" { if msg.SetupKey == "" {
hint := "" oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
if msg.Hint != nil {
hint = *msg.Hint
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
if err != nil { if err != nil {
state.Set(internal.StatusLoginFailed) state.Set(internal.StatusLoginFailed)
return nil, err return nil, err
@@ -1077,49 +1072,14 @@ func (s *Server) Status(
if msg.GetFullPeerStatus { if msg.GetFullPeerStatus {
s.runProbes(msg.ShouldRunProbes) s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus() fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := nbstatus.ToProtoFullStatus(fullStatus) pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory() pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
statusResponse.FullStatus = pbFullStatus statusResponse.FullStatus = pbFullStatus
} }
return &statusResponse, nil return &statusResponse, nil
} }
// getSSHServerState retrieves the current SSH server state including enabled status and active sessions
func (s *Server) getSSHServerState() *proto.SSHServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
enabled, sessions := engine.GetSSHServerStatus()
sshServerState := &proto.SSHServerState{
Enabled: enabled,
}
for _, session := range sessions {
sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{
Username: session.Username,
RemoteAddress: session.RemoteAddress,
Command: session.Command,
JwtUsername: session.JWTUsername,
})
}
return sshServerState
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer // GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey( func (s *Server) GetPeerSSHHostKey(
ctx context.Context, ctx context.Context,
@@ -1172,31 +1132,35 @@ func (s *Server) GetPeerSSHHostKey(
return response, nil return response, nil
} }
// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled) // getJWTCacheTTL returns the JWT cache TTL from environment variable or default
func (s *Server) getJWTCacheTTL() time.Duration { // NB_SSH_JWT_CACHE_TTL=0 disables caching
s.mutex.Lock() // NB_SSH_JWT_CACHE_TTL=<seconds> sets custom cache TTL
config := s.config func getJWTCacheTTL() time.Duration {
s.mutex.Unlock() envValue := os.Getenv("NB_SSH_JWT_CACHE_TTL")
if envValue == "" {
if config == nil || config.SSHJWTCacheTTL == nil { return defaultJWTCacheTTL
}
seconds, err := strconv.Atoi(envValue)
if err != nil {
log.Warnf("invalid NB_SSH_JWT_CACHE_TTL value %s, using default: %v", envValue, defaultJWTCacheTTL)
return defaultJWTCacheTTL return defaultJWTCacheTTL
} }
seconds := *config.SSHJWTCacheTTL
if seconds == 0 { if seconds == 0 {
log.Debug("SSH JWT cache disabled (configured to 0)") log.Info("SSH JWT cache disabled via NB_SSH_JWT_CACHE_TTL=0")
return 0 return 0
} }
ttl := time.Duration(seconds) * time.Second ttl := time.Duration(seconds) * time.Second
log.Debugf("SSH JWT cache TTL set to %v from config", ttl) log.Infof("SSH JWT cache TTL set to %v via NB_SSH_JWT_CACHE_TTL", ttl)
return ttl return ttl
} }
// RequestJWTAuth initiates JWT authentication flow for SSH // RequestJWTAuth initiates JWT authentication flow for SSH
func (s *Server) RequestJWTAuth( func (s *Server) RequestJWTAuth(
ctx context.Context, ctx context.Context,
msg *proto.RequestJWTAuthRequest, _ *proto.RequestJWTAuthRequest,
) (*proto.RequestJWTAuthResponse, error) { ) (*proto.RequestJWTAuthResponse, error) {
if ctx.Err() != nil { if ctx.Err() != nil {
return nil, ctx.Err() return nil, ctx.Err()
@@ -1210,7 +1174,7 @@ func (s *Server) RequestJWTAuth(
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured") return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
} }
jwtCacheTTL := s.getJWTCacheTTL() jwtCacheTTL := getJWTCacheTTL()
if jwtCacheTTL > 0 { if jwtCacheTTL > 0 {
if cachedToken, found := s.jwtCache.get(); found { if cachedToken, found := s.jwtCache.get(); found {
log.Debugf("JWT token found in cache, returning cached token for SSH authentication") log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
@@ -1222,17 +1186,8 @@ func (s *Server) RequestJWTAuth(
} }
} }
hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
if hint == "" {
hint = profilemanager.GetLoginHint()
}
isDesktop := isUnixRunningDesktop() isDesktop := isUnixRunningDesktop()
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint) oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop)
if err != nil { if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err) return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
} }
@@ -1283,7 +1238,7 @@ func (s *Server) WaitJWTToken(
token := tokenInfo.GetTokenToUse() token := tokenInfo.GetTokenToUse()
jwtCacheTTL := s.getJWTCacheTTL() jwtCacheTTL := getJWTCacheTTL()
if jwtCacheTTL > 0 { if jwtCacheTTL > 0 {
s.jwtCache.store(token, jwtCacheTTL) s.jwtCache.store(token, jwtCacheTTL)
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL) log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
@@ -1374,33 +1329,28 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
blockLANAccess := cfg.BlockLANAccess blockLANAccess := cfg.BlockLANAccess
enableSSHRoot := false enableSSHRoot := false
if cfg.EnableSSHRoot != nil { if s.config.EnableSSHRoot != nil {
enableSSHRoot = *cfg.EnableSSHRoot enableSSHRoot = *s.config.EnableSSHRoot
} }
enableSSHSFTP := false enableSSHSFTP := false
if cfg.EnableSSHSFTP != nil { if s.config.EnableSSHSFTP != nil {
enableSSHSFTP = *cfg.EnableSSHSFTP enableSSHSFTP = *s.config.EnableSSHSFTP
} }
enableSSHLocalPortForwarding := false enableSSHLocalPortForwarding := false
if cfg.EnableSSHLocalPortForwarding != nil { if s.config.EnableSSHLocalPortForwarding != nil {
enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding enableSSHLocalPortForwarding = *s.config.EnableSSHLocalPortForwarding
} }
enableSSHRemotePortForwarding := false enableSSHRemotePortForwarding := false
if cfg.EnableSSHRemotePortForwarding != nil { if s.config.EnableSSHRemotePortForwarding != nil {
enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding
} }
disableSSHAuth := false disableSSHAuth := false
if cfg.DisableSSHAuth != nil { if s.config.DisableSSHAuth != nil {
disableSSHAuth = *cfg.DisableSSHAuth disableSSHAuth = *s.config.DisableSSHAuth
}
sshJWTCacheTTL := int32(0)
if cfg.SSHJWTCacheTTL != nil {
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
} }
return &proto.GetConfigResponse{ return &proto.GetConfigResponse{
@@ -1427,7 +1377,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding, EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding, EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth, DisableSSHAuth: disableSSHAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
}, nil }, nil
} }
@@ -1535,7 +1484,7 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
log.Tracef("running client connection") log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan, s.logFile); err != nil { if err := s.connectClient.Run(runningChan); err != nil {
return err return err
} }
return nil return nil
@@ -1609,6 +1558,94 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
return defaultDuration return defaultDuration
} }
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
// sendTerminalNotification sends a terminal notification message // sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired. // to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error { func sendTerminalNotification() error {

View File

@@ -15,11 +15,6 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
@@ -295,7 +290,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
jobManager := job.NewJobManager(nil, store) peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@@ -316,16 +311,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) accountManager, err := server.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -72,7 +72,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
lazyConnectionEnabled := true lazyConnectionEnabled := true
blockInbound := true blockInbound := true
mtu := int64(1280) mtu := int64(1280)
sshJWTCacheTTL := int32(300)
req := &proto.SetConfigRequest{ req := &proto.SetConfigRequest{
ProfileName: profName, ProfileName: profName,
@@ -103,7 +102,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
CleanDNSLabels: false, CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute), DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu, Mtu: &mtu,
SshJWTCacheTTL: &sshJWTCacheTTL,
} }
_, err = s.SetConfig(ctx, req) _, err = s.SetConfig(ctx, req)
@@ -148,8 +146,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList()) require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval) require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
require.Equal(t, uint16(mtu), cfg.MTU) require.Equal(t, uint16(mtu), cfg.MTU)
require.NotNil(t, cfg.SSHJWTCacheTTL)
require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL)
verifyAllFieldsCovered(t, req) verifyAllFieldsCovered(t, req)
} }
@@ -171,36 +167,35 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
} }
expectedFields := map[string]bool{ expectedFields := map[string]bool{
"ManagementUrl": true, "ManagementUrl": true,
"AdminURL": true, "AdminURL": true,
"RosenpassEnabled": true, "RosenpassEnabled": true,
"RosenpassPermissive": true, "RosenpassPermissive": true,
"ServerSSHAllowed": true, "ServerSSHAllowed": true,
"InterfaceName": true, "InterfaceName": true,
"WireguardPort": true, "WireguardPort": true,
"OptionalPreSharedKey": true, "OptionalPreSharedKey": true,
"DisableAutoConnect": true, "DisableAutoConnect": true,
"NetworkMonitor": true, "NetworkMonitor": true,
"DisableClientRoutes": true, "DisableClientRoutes": true,
"DisableServerRoutes": true, "DisableServerRoutes": true,
"DisableDns": true, "DisableDns": true,
"DisableFirewall": true, "DisableFirewall": true,
"BlockLanAccess": true, "BlockLanAccess": true,
"DisableNotifications": true, "DisableNotifications": true,
"LazyConnectionEnabled": true, "LazyConnectionEnabled": true,
"BlockInbound": true, "BlockInbound": true,
"NatExternalIPs": true, "NatExternalIPs": true,
"CustomDNSAddress": true, "CustomDNSAddress": true,
"ExtraIFaceBlacklist": true, "ExtraIFaceBlacklist": true,
"DnsLabels": true, "DnsLabels": true,
"DnsRouteInterval": true, "DnsRouteInterval": true,
"Mtu": true, "Mtu": true,
"EnableSSHRoot": true, "EnableSSHRoot": true,
"EnableSSHSFTP": true, "EnableSSHSFTP": true,
"EnableSSHLocalPortForwarding": true, "EnableSSHLocalPortForward": true,
"EnableSSHRemotePortForwarding": true, "EnableSSHRemotePortForward": true,
"DisableSSHAuth": true, "DisableSSHAuth": true,
"SshJWTCacheTTL": true,
} }
val := reflect.ValueOf(req).Elem() val := reflect.ValueOf(req).Elem()
@@ -256,10 +251,9 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"mtu": "Mtu", "mtu": "Mtu",
"enable-ssh-root": "EnableSSHRoot", "enable-ssh-root": "EnableSSHRoot",
"enable-ssh-sftp": "EnableSSHSFTP", "enable-ssh-sftp": "EnableSSHSFTP",
"enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding", "enable-ssh-local-port-forwarding": "EnableSSHLocalPortForward",
"enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding", "enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForward",
"disable-ssh-auth": "DisableSSHAuth", "disable-ssh-auth": "DisableSSHAuth",
"ssh-jwt-cache-ttl": "SshJWTCacheTTL",
} }
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).

View File

@@ -20,7 +20,6 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/ssh/detection"
@@ -215,7 +214,7 @@ func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) erro
} }
} }
// handleCommandError processes command execution errors // handleCommandError processes command execution errors, treating exit codes as normal
func (c *Client) handleCommandError(err error) error { func (c *Client) handleCommandError(err error) error {
if err == nil { if err == nil {
return nil return nil
@@ -223,11 +222,11 @@ func (c *Client) handleCommandError(err error) error {
var e *ssh.ExitError var e *ssh.ExitError
var em *ssh.ExitMissingError var em *ssh.ExitMissingError
if errors.As(err, &e) || errors.As(err, &em) { if !errors.As(err, &e) && !errors.As(err, &em) {
return err return fmt.Errorf("execute command: %w", err)
} }
return fmt.Errorf("execute command: %w", err) return nil
} }
// setupContextCancellation sets up context cancellation for a session // setupContextCancellation sets up context cancellation for a session
@@ -282,12 +281,6 @@ type DialOptions struct {
// Dial connects to the given ssh server with specified options // Dial connects to the given ssh server with specified options
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) { func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
daemonAddr := opts.DaemonAddr
if daemonAddr == "" {
daemonAddr = getDefaultDaemonAddr()
}
opts.DaemonAddr = daemonAddr
hostKeyCallback, err := createHostKeyCallback(opts) hostKeyCallback, err := createHostKeyCallback(opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("create host key callback: %w", err) return nil, fmt.Errorf("create host key callback: %w", err)
@@ -307,6 +300,11 @@ func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, er
config.Auth = append(config.Auth, authMethod) config.Auth = append(config.Auth, authMethod)
} }
daemonAddr := opts.DaemonAddr
if daemonAddr == "" {
daemonAddr = getDefaultDaemonAddr()
}
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken) return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
} }
@@ -367,8 +365,6 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
// requestJWTToken requests a JWT token from the NetBird daemon // requestJWTToken requests a JWT token from the NetBird daemon
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) { func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
hint := profilemanager.GetLoginHint()
conn, err := connectToDaemon(daemonAddr) conn, err := connectToDaemon(daemonAddr)
if err != nil { if err != nil {
return "", fmt.Errorf("connect to daemon: %w", err) return "", fmt.Errorf("connect to daemon: %w", err)
@@ -376,7 +372,7 @@ func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (st
defer conn.Close() defer conn.Close()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint) return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache)
} }
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon // verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
@@ -468,7 +464,7 @@ func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicK
return nil return nil
} }
} }
return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname) return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
} }
func getKnownHostsFilesList(knownHostsFile string) []string { func getKnownHostsFilesList(knownHostsFile string) []string {
@@ -512,7 +508,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
go func() { go func() {
defer func() { defer func() {
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := localListener.Close(); err != nil {
log.Debugf("local listener close error: %v", err) log.Debugf("local listener close error: %v", err)
} }
}() }()
@@ -530,9 +526,6 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
}() }()
<-ctx.Done() <-ctx.Done()
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
log.Debugf("local listener close error: %v", err)
}
return ctx.Err() return ctx.Err()
} }

View File

@@ -137,10 +137,10 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
const numClients = 3 const numClients = 3
clients := make([]*Client, numClients) clients := make([]*Client, numClients)
currentUser := testutil.GetTestUsername(t)
for i := 0; i < numClients; i++ { for i := 0; i < numClients; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i), DialOptions{
InsecureSkipVerify: true, InsecureSkipVerify: true,
}) })
cancel() cancel()

View File

@@ -15,26 +15,17 @@ import (
) )
func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error { func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
stdinFd := int(os.Stdin.Fd()) fd := int(os.Stdout.Fd())
if !term.IsTerminal(stdinFd) { if !term.IsTerminal(fd) {
return c.setupNonTerminalMode(ctx, session) return c.setupNonTerminalMode(ctx, session)
} }
fd := int(os.Stdin.Fd())
state, err := term.MakeRaw(fd) state, err := term.MakeRaw(fd)
if err != nil { if err != nil {
return c.setupNonTerminalMode(ctx, session) return c.setupNonTerminalMode(ctx, session)
} }
if err := c.setupTerminal(session, fd); err != nil {
if restoreErr := term.Restore(fd, state); restoreErr != nil {
log.Debugf("restore terminal state: %v", restoreErr)
}
return err
}
c.terminalState = state c.terminalState = state
c.terminalFd = fd c.terminalFd = fd
@@ -64,10 +55,27 @@ func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) er
} }
}() }()
return nil return c.setupTerminal(session, fd)
} }
func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error { func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
w, h := 80, 24
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
terminal := os.Getenv("TERM")
if terminal == "" {
terminal = "xterm-256color"
}
if err := session.RequestPty(terminal, h, w, modes); err != nil {
return fmt.Errorf("request pty: %w", err)
}
return nil return nil
} }

View File

@@ -62,10 +62,11 @@ func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) erro
if err := c.saveWindowsConsoleState(); err != nil { if err := c.saveWindowsConsoleState(); err != nil {
var consoleErr *ConsoleUnavailableError var consoleErr *ConsoleUnavailableError
if errors.As(err, &consoleErr) { if errors.As(err, &consoleErr) {
log.Debugf("console unavailable, not requesting PTY: %v", err) log.Debugf("console unavailable, continuing with defaults: %v", err)
return nil c.terminalFd = 0
} else {
return fmt.Errorf("save console state: %w", err)
} }
return fmt.Errorf("save console state: %w", err)
} }
if err := c.enableWindowsVirtualTerminal(); err != nil { if err := c.enableWindowsVirtualTerminal(); err != nil {
@@ -104,14 +105,7 @@ func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) erro
ssh.VREPRINT: 18, // Ctrl+R ssh.VREPRINT: 18, // Ctrl+R
} }
if err := session.RequestPty("xterm-256color", h, w, modes); err != nil { return session.RequestPty("xterm-256color", h, w, modes)
if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil {
log.Debugf("restore Windows console state: %v", restoreErr)
}
return fmt.Errorf("request pty: %w", err)
}
return nil
} }
func (c *Client) saveWindowsConsoleState() error { func (c *Client) saveWindowsConsoleState() error {

View File

@@ -68,12 +68,8 @@ func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKe
} }
// RequestJWTToken requests or retrieves a JWT token for SSH authentication // RequestJWTToken requests or retrieves a JWT token for SSH authentication
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) { func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool) (string, error) {
req := &proto.RequestJWTAuthRequest{} authResponse, err := client.RequestJWTAuth(ctx, &proto.RequestJWTAuthRequest{})
if hint != "" {
req.Hint = &hint
}
authResponse, err := client.RequestJWTAuth(ctx, req)
if err != nil { if err != nil {
return "", fmt.Errorf("request JWT auth: %w", err) return "", fmt.Errorf("request JWT auth: %w", err)
} }

View File

@@ -18,7 +18,6 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/ssh/detection"
@@ -39,7 +38,6 @@ type SSHProxy struct {
targetHost string targetHost string
targetPort int targetPort int
stderr io.Writer stderr io.Writer
conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient daemonClient proto.DaemonServiceClient
} }
@@ -55,22 +53,12 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHP
targetHost: targetHost, targetHost: targetHost,
targetPort: targetPort, targetPort: targetPort,
stderr: stderr, stderr: stderr,
conn: grpcConn,
daemonClient: proto.NewDaemonServiceClient(grpcConn), daemonClient: proto.NewDaemonServiceClient(grpcConn),
}, nil }, nil
} }
func (p *SSHProxy) Close() error {
if p.conn != nil {
return p.conn.Close()
}
return nil
}
func (p *SSHProxy) Connect(ctx context.Context) error { func (p *SSHProxy) Connect(ctx context.Context) error {
hint := profilemanager.GetLoginHint() jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true)
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
if err != nil { if err != nil {
return fmt.Errorf(jwtAuthErrorMsg, err) return fmt.Errorf(jwtAuthErrorMsg, err)
} }
@@ -156,7 +144,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
if len(session.Command()) > 0 { if len(session.Command()) > 0 {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
log.Debugf("run command: %v", err) log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
} }
return return
} }
@@ -167,16 +154,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
} }
if err := serverSession.Wait(); err != nil { if err := serverSession.Wait(); err != nil {
log.Debugf("session wait: %v", err) log.Debugf("session wait: %v", err)
p.handleProxyExitCode(session, err)
}
}
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
var exitErr *cryptossh.ExitError
if errors.As(err, &exitErr) {
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
log.Debugf("set exit status: %v", exitErr)
}
} }
} }

View File

@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command())) logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty) execCmd, err := s.createCommand(privilegeResult, session, hasPty)
if err != nil { if err != nil {
logger.Errorf("%s creation failed: %v", commandType, err) logger.Errorf("%s creation failed: %v", commandType, err)
@@ -42,59 +42,31 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
return return
} }
if !hasPty { if s.executeCommand(logger, session, execCmd) {
if s.executeCommand(logger, session, execCmd, cleanup) {
logger.Debugf("%s execution completed", commandType)
}
return
}
defer cleanup()
ptyReq, _, _ := session.Pty()
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
logger.Debugf("%s execution completed", commandType) logger.Debugf("%s execution completed", commandType)
} }
} }
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, error) {
localUser := privilegeResult.User localUser := privilegeResult.User
if localUser == nil {
return nil, nil, errors.New("no user in privilege result")
}
// If PTY requested but su doesn't support --pty, skip su and use executor
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
if hasPty && !s.suSupportsPty {
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
}
// Try su first for system integration (PAM/audit) when privileged // Try su first for system integration (PAM/audit) when privileged
cmd, err := s.createSuCommand(session, localUser, hasPty) cmd, err := s.createSuCommand(session, localUser, hasPty)
if err != nil || privilegeResult.UsedFallback { if err != nil || privilegeResult.UsedFallback {
log.Debugf("su command failed, falling back to executor: %v", err) log.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) cmd, err = s.createExecutorCommand(session, localUser, hasPty)
if err != nil { }
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
} if err != nil {
cmd.Env = s.prepareCommandEnv(localUser, session) return nil, fmt.Errorf("create command with privileges: %w", err)
return cmd, cleanup, nil
} }
cmd.Env = s.prepareCommandEnv(localUser, session) cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, func() {}, nil return cmd, nil
} }
// executeCommand executes the command and handles I/O and exit codes // executeCommand executes the command and handles I/O and exit codes
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool { func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
defer cleanup()
s.setupProcessGroup(execCmd) s.setupProcessGroup(execCmd)
stdinPipe, err := execCmd.StdinPipe() stdinPipe, err := execCmd.StdinPipe()
@@ -107,7 +79,7 @@ func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd
} }
execCmd.Stdout = session execCmd.Stdout = session
execCmd.Stderr = session.Stderr() execCmd.Stderr = session
if execCmd.Dir != "" { if execCmd.Dir != "" {
if _, err := os.Stat(execCmd.Dir); err != nil { if _, err := os.Stat(execCmd.Dir); err != nil {

View File

@@ -3,13 +3,11 @@
package server package server
import ( import (
"context"
"errors" "errors"
"os/exec" "os/exec"
"os/user" "os/user"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
) )
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform") var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
@@ -20,8 +18,8 @@ func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd
} }
// createExecutorCommand is not supported on JS/WASM // createExecutorCommand is not supported on JS/WASM
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
return nil, nil, errNotSupported return nil, errNotSupported
} }
// prepareCommandEnv is not supported on JS/WASM // prepareCommandEnv is not supported on JS/WASM
@@ -34,19 +32,5 @@ func (s *Server) setupProcessGroup(_ *exec.Cmd) {
} }
// killProcessGroup is not supported on JS/WASM // killProcessGroup is not supported on JS/WASM
func (s *Server) killProcessGroup(*exec.Cmd) { func (s *Server) killProcessGroup(_ *exec.Cmd) {
}
// detectSuPtySupport always returns false on JS/WASM
func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// executeCommandWithPty is not supported on JS/WASM
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
logger.Errorf("PTY command execution not supported on JS/WASM")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
} }

View File

@@ -3,14 +3,12 @@
package server package server
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os" "os"
"os/exec" "os/exec"
"os/user" "os/user"
"strings"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@@ -20,6 +18,47 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
if err != nil {
return nil, fmt.Errorf("su command not available: %w", err)
}
command := session.RawCommand()
if command == "" {
return nil, fmt.Errorf("no command specified for su execution")
}
// TODO: handle pty flag if available
args := []string{"-l", localUser.Username, "-c", command}
cmd := exec.CommandContext(session.Context(), suPath, args...)
cmd.Dir = localUser.HomeDir
return cmd, nil
}
// getShellCommandArgs returns the shell command and arguments for executing a command string
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell, "-l"}
}
return []string{shell, "-l", "-c", cmdString}
}
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}
// ptyManager manages Pty file operations with thread safety // ptyManager manages Pty file operations with thread safety
type ptyManager struct { type ptyManager struct {
file *os.File file *os.File
@@ -49,7 +88,7 @@ func (pm *ptyManager) Setsize(ws *pty.Winsize) error {
pm.mu.RLock() pm.mu.RLock()
defer pm.mu.RUnlock() defer pm.mu.RUnlock()
if pm.closed { if pm.closed {
return errors.New("pty is closed") return errors.New("Pty is closed")
} }
return pty.Setsize(pm.file, ws) return pty.Setsize(pm.file, ws)
} }
@@ -58,78 +97,6 @@ func (pm *ptyManager) File() *os.File {
return pm.file return pm.file
} }
// detectSuPtySupport checks if su supports the --pty flag
func (s *Server) detectSuPtySupport(ctx context.Context) bool {
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
cmd := exec.CommandContext(ctx, "su", "--help")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("su --help failed (may not support --help): %v", err)
return false
}
supported := strings.Contains(string(output), "--pty")
log.Debugf("su --pty support detected: %v", supported)
return supported
}
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
if err != nil {
return nil, fmt.Errorf("su command not available: %w", err)
}
command := session.RawCommand()
if command == "" {
return nil, fmt.Errorf("no command specified for su execution")
}
args := []string{"-l"}
if hasPty && s.suSupportsPty {
args = append(args, "--pty")
}
args = append(args, localUser.Username, "-c", command)
cmd := exec.CommandContext(session.Context(), suPath, args...)
cmd.Dir = localUser.HomeDir
return cmd, nil
}
// getShellCommandArgs returns the shell command and arguments for executing a command string
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell, "-l"}
}
return []string{shell, "-l", "-c", cmdString}
}
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}
// executeCommandWithPty executes a command with PTY allocation
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
termType := ptyReq.Term
if termType == "" {
termType = "xterm-256color"
}
execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType))
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session) execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
if err != nil { if err != nil {
@@ -144,12 +111,14 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
return false return false
} }
logger.Infof("starting interactive shell: %s", execCmd.Path) shell := execCmd.Path
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) cmd := session.Command()
} if len(cmd) == 0 {
logger.Infof("starting interactive shell: %s", shell)
} else {
logger.Infof("executing command: %s", safeLogCommand(cmd))
}
// runPtyCommand runs a command with PTY management (common code for interactive and command execution)
func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq) ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
if err != nil { if err != nil {
logger.Errorf("Pty start failed: %v", err) logger.Errorf("Pty start failed: %v", err)
@@ -301,29 +270,9 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
pgid := cmd.Process.Pid pgid := cmd.Process.Pid
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil { if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
logger.Debugf("kill process group SIGTERM: %v", err) logger.Debugf("kill process group SIGTERM failed: %v", err)
return if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
} logger.Debugf("kill process group SIGKILL failed: %v", err)
const gracePeriod = 500 * time.Millisecond
const checkInterval = 50 * time.Millisecond
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
timeout := time.After(gracePeriod)
for {
select {
case <-timeout:
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
logger.Debugf("kill process group SIGKILL: %v", err)
}
return
case <-ticker.C:
if err := syscall.Kill(-pgid, 0); err != nil {
return
}
} }
} }
} }

View File

@@ -1,3 +1,5 @@
//go:build windows
package server package server
import ( import (
@@ -7,6 +9,7 @@ import (
"os/exec" "os/exec"
"os/user" "os/user"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"unsafe" "unsafe"
@@ -31,11 +34,6 @@ func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
} }
}() }()
return s.getUserEnvironmentWithToken(userToken, username, domain)
}
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
userProfile, err := s.loadUserProfile(userToken, username, domain) userProfile, err := s.loadUserProfile(userToken, username, domain)
if err != nil { if err != nil {
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
@@ -268,11 +266,6 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []
} }
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
if privilegeResult.User == nil {
logger.Errorf("no user in privilege result")
return false
}
cmd := session.Command() cmd := session.Command()
shell := getUserShell(privilegeResult.User.Uid) shell := getUserShell(privilegeResult.User.Uid)
@@ -282,6 +275,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
logger.Infof("executing command: %s", safeLogCommand(cmd)) logger.Infof("executing command: %s", safeLogCommand(cmd))
} }
// Always use user switching on Windows - no direct execution
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd) s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
return true return true
} }
@@ -295,8 +289,45 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
} }
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) { func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
logger.Info("starting interactive shell") localUser := privilegeResult.User
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
var command string
rawCmd := session.RawCommand()
if rawCmd != "" {
command = rawCmd
}
req := PtyExecutionRequest{
Shell: shell,
Command: command,
Width: ptyReq.Window.Width,
Height: ptyReq.Window.Height,
Username: username,
Domain: domain,
}
err := executePtyCommandWithUserToken(session.Context(), session, req)
if err != nil {
logger.Errorf("Windows ConPty with user switching failed: %v", err)
var errorMsg string
if runtime.GOOS == "windows" {
errorMsg = "Windows user switching failed - NetBird must run as a Windows service or with elevated privileges for user switching\r\n"
} else {
errorMsg = "User switching failed - login command not available\r\n"
}
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
logger.Debugf(errWriteSession, writeErr)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return
}
logger.Debugf("Windows ConPty command execution with user switching completed")
} }
type PtyExecutionRequest struct { type PtyExecutionRequest struct {
@@ -324,7 +355,7 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re
}() }()
server := &Server{} server := &Server{}
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain) userEnv, err := server.getUserEnvironment(req.Username, req.Domain)
if err != nil { if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err) log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
userEnv = os.Environ() userEnv = os.Environ()
@@ -377,54 +408,3 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
logger.Debugf("kill process failed: %v", err) logger.Debugf("kill process failed: %v", err)
} }
} }
// detectSuPtySupport always returns false on Windows as su is not available
func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand()
if command == "" {
logger.Error("no command specified for PTY execution")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
}
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
localUser := privilegeResult.User
if localUser == nil {
logger.Errorf("no user in privilege result")
return false
}
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
req := PtyExecutionRequest{
Shell: shell,
Command: command,
Width: ptyReq.Window.Width,
Height: ptyReq.Window.Height,
Username: username,
Domain: domain,
}
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
logger.Errorf("ConPty execution failed: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
logger.Debug("ConPty execution completed")
return true
}

View File

@@ -61,14 +61,12 @@ const (
convertDomainError = "convert domain to UTF16: %w" convertDomainError = "convert domain to UTF16: %w"
) )
// CreateWindowsExecutorCommand creates a Windows command with privilege dropping. func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, error) {
// The caller must close the returned token handle after starting the process.
func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, windows.Token, error) {
if config.Username == "" { if config.Username == "" {
return nil, 0, errors.New("username cannot be empty") return nil, errors.New("username cannot be empty")
} }
if config.Shell == "" { if config.Shell == "" {
return nil, 0, errors.New("shell cannot be empty") return nil, errors.New("shell cannot be empty")
} }
shell := config.Shell shell := config.Shell
@@ -82,13 +80,13 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
cmd, token, err := pd.CreateWindowsProcessAsUser( cmd, err := pd.CreateWindowsProcessAsUser(
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir) ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("create Windows process as user: %w", err) return nil, fmt.Errorf("create Windows process as user: %w", err)
} }
return cmd, token, nil return cmd, nil
} }
const ( const (
@@ -183,6 +181,7 @@ func newLsaString(s string) lsaString {
func generateS4UUserToken(username, domain string) (windows.Handle, error) { func generateS4UUserToken(username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain) userCpn := buildUserCpn(username, domain)
// Use proper domain detection logic instead of simple string check
pd := NewPrivilegeDropper() pd := NewPrivilegeDropper()
isDomainUser := !pd.isLocalUser(domain) isDomainUser := !pd.isLocalUser(domain)
@@ -344,7 +343,7 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er
upnOffset := structSize upnOffset := structSize
upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset)) upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset))
copy((*[1025]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16) copy((*[512]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16)
s4uLogon.ClientUpn = unicodeString{ s4uLogon.ClientUpn = unicodeString{
Length: uint16((len(upnUtf16) - 1) * 2), Length: uint16((len(upnUtf16) - 1) * 2),
@@ -516,34 +515,31 @@ func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsernam
return token, nil return token, nil
} }
// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables). // CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables)
// The caller must close the returned token handle after starting the process. func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, error) {
func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, windows.Token, error) { fullUsername := buildUserCpn(username, domain)
token, err := pd.createToken(username, domain) token, err := pd.createToken(username, domain)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("user authentication: %w", err) return nil, fmt.Errorf("user authentication: %w", err)
} }
log.Debugf("using S4U authentication for user %s", fullUsername)
defer func() { defer func() {
if err := windows.CloseHandle(token); err != nil { if err := windows.CloseHandle(token); err != nil {
log.Debugf("close impersonation token: %v", err) log.Debugf("close impersonation token error: %v", err)
} }
}() }()
cmd, primaryToken, err := pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir) return pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir)
if err != nil {
return nil, 0, err
}
return cmd, primaryToken, nil
} }
// createProcessWithToken creates process with the specified token and executable path. // createProcessWithToken creates process with the specified token and executable path
// The caller must close the returned token handle after starting the process. func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, error) {
func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, windows.Token, error) {
cmd := exec.CommandContext(ctx, executablePath, args[1:]...) cmd := exec.CommandContext(ctx, executablePath, args[1:]...)
cmd.Dir = workingDir cmd.Dir = workingDir
// Duplicate the token to create a primary token that can be used to start a new process
var primaryToken windows.Token var primaryToken windows.Token
err := windows.DuplicateTokenEx( err := windows.DuplicateTokenEx(
sourceToken, sourceToken,
@@ -554,14 +550,14 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
&primaryToken, &primaryToken,
) )
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("duplicate token to primary token: %w", err) return nil, fmt.Errorf("duplicate token to primary token: %w", err)
} }
cmd.SysProcAttr = &syscall.SysProcAttr{ cmd.SysProcAttr = &syscall.SysProcAttr{
Token: syscall.Token(primaryToken), Token: syscall.Token(primaryToken),
} }
return cmd, primaryToken, nil return cmd, nil
} }
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub) // createSuCommand creates a command using su -l -c for privilege switching (Windows stub)

View File

@@ -311,16 +311,6 @@ func TestJWTFailClose(t *testing.T) {
"iat": time.Now().Add(-2 * time.Hour).Unix(), "iat": time.Now().Add(-2 * time.Hour).Unix(),
}, },
}, },
{
name: "blocks_token_exceeding_max_age",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
},
},
} }
for _, tc := range testCases { for _, tc := range testCases {

View File

@@ -17,7 +17,6 @@ import (
gojwt "github.com/golang-jwt/jwt/v5" gojwt "github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh" cryptossh "golang.org/x/crypto/ssh"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -106,20 +105,12 @@ type sshConnectionState struct {
remoteAddr string remoteAddr string
} }
type authKey string
func newAuthKey(username string, remoteAddr net.Addr) authKey {
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
}
type Server struct { type Server struct {
sshServer *ssh.Server sshServer *ssh.Server
mu sync.RWMutex mu sync.RWMutex
hostKeyPEM []byte hostKeyPEM []byte
sessions map[SessionKey]ssh.Session sessions map[SessionKey]ssh.Session
sessionCancels map[ConnectionKey]context.CancelFunc sessionCancels map[ConnectionKey]context.CancelFunc
sessionJWTUsers map[SessionKey]string
pendingAuthJWT map[authKey]string
allowLocalPortForwarding bool allowLocalPortForwarding bool
allowRemotePortForwarding bool allowRemotePortForwarding bool
@@ -137,8 +128,6 @@ type Server struct {
jwtValidator *jwt.Validator jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig jwtConfig *JWTConfig
suSupportsPty bool
} }
type JWTConfig struct { type JWTConfig struct {
@@ -157,14 +146,6 @@ type Config struct {
HostKeyPEM []byte HostKeyPEM []byte
} }
// SessionInfo contains information about an active SSH session
type SessionInfo struct {
Username string
RemoteAddress string
Command string
JWTUsername string
}
// New creates an SSH server instance with the provided host key and optional JWT configuration // New creates an SSH server instance with the provided host key and optional JWT configuration
// If jwtConfig is nil, JWT authentication is disabled // If jwtConfig is nil, JWT authentication is disabled
func New(config *Config) *Server { func New(config *Config) *Server {
@@ -172,8 +153,6 @@ func New(config *Config) *Server {
mu: sync.RWMutex{}, mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM, hostKeyPEM: config.HostKeyPEM,
sessions: make(map[SessionKey]ssh.Session), sessions: make(map[SessionKey]ssh.Session),
sessionJWTUsers: make(map[SessionKey]string),
pendingAuthJWT: make(map[authKey]string),
remoteForwardListeners: make(map[ForwardKey]net.Listener), remoteForwardListeners: make(map[ForwardKey]net.Listener),
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState), sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil, jwtEnabled: config.JWT != nil,
@@ -192,8 +171,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return errors.New("SSH server is already running") return errors.New("SSH server is already running")
} }
s.suSupportsPty = s.detectSuPtySupport(ctx)
ln, addrDesc, err := s.createListener(ctx, addr) ln, addrDesc, err := s.createListener(ctx, addr)
if err != nil { if err != nil {
return fmt.Errorf("create listener: %w", err) return fmt.Errorf("create listener: %w", err)
@@ -209,7 +186,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
log.Infof("SSH server started on %s", addrDesc) log.Infof("SSH server started on %s", addrDesc)
go func() { go func() {
if err := sshServer.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) { if err := sshServer.Serve(ln); !isShutdownError(err) {
log.Errorf("SSH server error: %v", err) log.Errorf("SSH server error: %v", err)
} }
}() }()
@@ -252,58 +229,15 @@ func (s *Server) Stop() error {
return nil return nil
} }
if err := s.sshServer.Close(); err != nil { if err := s.sshServer.Close(); err != nil && !isShutdownError(err) {
log.Debugf("close SSH server: %v", err) return fmt.Errorf("shutdown SSH server: %w", err)
} }
s.sshServer = nil s.sshServer = nil
maps.Clear(s.sessions)
maps.Clear(s.sessionJWTUsers)
maps.Clear(s.pendingAuthJWT)
maps.Clear(s.sshConnections)
for _, cancelFunc := range s.sessionCancels {
cancelFunc()
}
maps.Clear(s.sessionCancels)
for _, listener := range s.remoteForwardListeners {
if err := listener.Close(); err != nil {
log.Debugf("close remote forward listener: %v", err)
}
}
maps.Clear(s.remoteForwardListeners)
return nil return nil
} }
// GetStatus returns the current status of the SSH server and active sessions
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
s.mu.RLock()
defer s.mu.RUnlock()
enabled = s.sshServer != nil
for sessionKey, session := range s.sessions {
cmd := "<interactive shell>"
if len(session.Command()) > 0 {
cmd = safeLogCommand(session.Command())
}
jwtUsername := s.sessionJWTUsers[sessionKey]
sessions = append(sessions, SessionInfo{
Username: session.User(),
RemoteAddress: session.RemoteAddr().String(),
Command: cmd,
JWTUsername: jwtUsername,
})
}
return enabled, sessions
}
// SetNetstackNet sets the netstack network for userspace networking // SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) { func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock() s.mu.Lock()
@@ -388,15 +322,10 @@ func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
} }
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error { func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
if jwtConfig == nil { if jwtConfig == nil || jwtConfig.MaxTokenAge <= 0 {
return nil return nil
} }
maxTokenAge := jwtConfig.MaxTokenAge
if maxTokenAge <= 0 {
maxTokenAge = DefaultJWTMaxTokenAge
}
claims, ok := token.Claims.(gojwt.MapClaims) claims, ok := token.Claims.(gojwt.MapClaims)
if !ok { if !ok {
userID := extractUserID(token) userID := extractUserID(token)
@@ -411,7 +340,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
issuedAt := time.Unix(int64(iat), 0) issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt) tokenAge := time.Since(issuedAt)
maxAge := time.Duration(maxTokenAge) * time.Second maxAge := time.Duration(jwtConfig.MaxTokenAge) * time.Second
if tokenAge > maxAge { if tokenAge > maxAge {
userID := getUserIDFromClaims(claims) userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge) return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
@@ -508,11 +437,6 @@ func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
return false return false
} }
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId
s.mu.Unlock()
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr()) log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
return true return true
} }
@@ -608,6 +532,19 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
return conn return conn
} }
func isShutdownError(err error) bool {
if errors.Is(err, net.ErrClosed) {
return true
}
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "accept" {
return true
}
return false
}
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
if err := enableUserSwitching(); err != nil { if err := enableUserSwitching(); err != nil {
log.Warnf("failed to enable user switching: %v", err) log.Warnf("failed to enable user switching: %v", err)

View File

@@ -65,12 +65,9 @@ func TestSSHServerIntegration(t *testing.T) {
return return
} }
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr started <- actualAddr
addrPort, _ := netip.ParseAddrPort(actualAddr)
errChan <- server.Start(context.Background(), addrPort)
}() }()
select { select {
@@ -82,6 +79,8 @@ func TestSSHServerIntegration(t *testing.T) {
t.Fatal("Server start timeout") t.Fatal("Server start timeout")
} }
// Server is ready when we get the started signal
defer func() { defer func() {
err := server.Stop() err := server.Stop()
require.NoError(t, err) require.NoError(t, err)
@@ -167,12 +166,9 @@ func TestSSHServerMultipleConnections(t *testing.T) {
return return
} }
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr started <- actualAddr
addrPort, _ := netip.ParseAddrPort(actualAddr)
errChan <- server.Start(context.Background(), addrPort)
}() }()
select { select {
@@ -184,6 +180,8 @@ func TestSSHServerMultipleConnections(t *testing.T) {
t.Fatal("Server start timeout") t.Fatal("Server start timeout")
} }
// Server is ready when we get the started signal
defer func() { defer func() {
err := server.Stop() err := server.Stop()
require.NoError(t, err) require.NoError(t, err)
@@ -279,12 +277,9 @@ func TestSSHServerNoAuthMode(t *testing.T) {
return return
} }
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr started <- actualAddr
addrPort, _ := netip.ParseAddrPort(actualAddr)
errChan <- server.Start(context.Background(), addrPort)
}() }()
select { select {
@@ -296,6 +291,8 @@ func TestSSHServerNoAuthMode(t *testing.T) {
t.Fatal("Server start timeout") t.Fatal("Server start timeout")
} }
// Server is ready when we get the started signal
defer func() { defer func() {
err := server.Stop() err := server.Stop()
require.NoError(t, err) require.NoError(t, err)
@@ -364,12 +361,9 @@ func TestSSHServerStartStopCycle(t *testing.T) {
return return
} }
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr started <- actualAddr
addrPort, _ := netip.ParseAddrPort(actualAddr)
errChan <- server.Start(context.Background(), addrPort)
}() }()
select { select {

View File

@@ -11,29 +11,13 @@ import (
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
) )
// sessionHandler handles SSH sessions // sessionHandler handles SSH sessions
func (s *Server) sessionHandler(session ssh.Session) { func (s *Server) sessionHandler(session ssh.Session) {
sessionKey := s.registerSession(session) sessionKey := s.registerSession(session)
key := newAuthKey(session.User(), session.RemoteAddr())
s.mu.Lock()
jwtUsername := s.pendingAuthJWT[key]
if jwtUsername != "" {
s.sessionJWTUsers[sessionKey] = jwtUsername
delete(s.pendingAuthJWT, key)
}
s.mu.Unlock()
logger := log.WithField("session", sessionKey) logger := log.WithField("session", sessionKey)
if jwtUsername != "" { logger.Infof("SSH session started")
logger = logger.WithField("jwt_user", jwtUsername)
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
} else {
logger.Infof("SSH session started")
}
sessionStart := time.Now() sessionStart := time.Now()
defer s.unregisterSession(sessionKey, session) defer s.unregisterSession(sessionKey, session)
@@ -102,10 +86,9 @@ func (s *Server) registerSession(session ssh.Session) SessionKey {
return sessionKey return sessionKey
} }
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) { func (s *Server) unregisterSession(sessionKey SessionKey, _ ssh.Session) {
s.mu.Lock() s.mu.Lock()
delete(s.sessions, sessionKey) delete(s.sessions, sessionKey)
delete(s.sessionJWTUsers, sessionKey)
// Cancel all port forwarding connections for this session // Cancel all port forwarding connections for this session
var connectionsToCancel []ConnectionKey var connectionsToCancel []ConnectionKey
@@ -123,12 +106,6 @@ func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
} }
} }
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
delete(s.sshConnections, sshConn)
}
}
s.mu.Unlock() s.mu.Unlock()
} }

View File

@@ -62,12 +62,9 @@ func TestSSHServer_SFTPSubsystem(t *testing.T) {
return return
} }
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr started <- actualAddr
addrPort, _ := netip.ParseAddrPort(actualAddr)
errChan <- server.Start(context.Background(), addrPort)
}() }()
select { select {
@@ -171,12 +168,9 @@ func TestSSHServer_SFTPDisabled(t *testing.T) {
return return
} }
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr started <- actualAddr
addrPort, _ := netip.ParseAddrPort(actualAddr)
errChan <- server.Start(context.Background(), addrPort)
}() }()
select { select {

View File

@@ -14,14 +14,13 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
// createSftpCommand creates a Windows SFTP command with user switching. // createSftpCommand creates a Windows SFTP command with user switching
// The caller must close the returned token handle after starting the process. func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, error) {
func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, windows.Token, error) {
username, domain := s.parseUsername(targetUser.Username) username, domain := s.parseUsername(targetUser.Username)
netbirdPath, err := os.Executable() netbirdPath, err := os.Executable()
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("get netbird executable path: %w", err) return nil, fmt.Errorf("get netbird executable path: %w", err)
} }
args := []string{ args := []string{
@@ -34,32 +33,27 @@ func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*ex
pd := NewPrivilegeDropper() pd := NewPrivilegeDropper()
token, err := pd.createToken(username, domain) token, err := pd.createToken(username, domain)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("create token: %w", err) return nil, fmt.Errorf("create token: %w", err)
} }
defer func() { defer func() {
if err := windows.CloseHandle(token); err != nil { if err := windows.CloseHandle(token); err != nil {
log.Warnf("failed to close impersonation token: %v", err) log.Warnf("failed to close Windows token handle: %v", err)
} }
}() }()
cmd, primaryToken, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir) cmd, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("create SFTP command: %w", err) return nil, fmt.Errorf("create SFTP command: %w", err)
} }
log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username) log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username)
return cmd, primaryToken, nil return cmd, nil
} }
// executeSftpCommand executes a Windows SFTP command with proper I/O handling // executeSftpCommand executes a Windows SFTP command with proper I/O handling
func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token windows.Token) error { func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd) error {
defer func() {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
log.Debugf("close primary token: %v", err)
}
}()
sftpCmd.Stdin = sess sftpCmd.Stdin = sess
sftpCmd.Stdout = sess sftpCmd.Stdout = sess
sftpCmd.Stderr = sess.Stderr() sftpCmd.Stderr = sess.Stderr()
@@ -83,9 +77,9 @@ func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token w
// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping // executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error { func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
sftpCmd, token, err := s.createSftpCommand(targetUser, sess) sftpCmd, err := s.createSftpCommand(targetUser, sess)
if err != nil { if err != nil {
return fmt.Errorf("create sftp: %w", err) return fmt.Errorf("create sftp: %w", err)
} }
return s.executeSftpCommand(sess, sftpCmd, token) return s.executeSftpCommand(sess, sftpCmd)
} }

View File

@@ -99,17 +99,12 @@ func getShellFromPasswd(userID string) string {
// prepareUserEnv prepares environment variables for user execution // prepareUserEnv prepares environment variables for user execution
func prepareUserEnv(user *user.User, shell string) []string { func prepareUserEnv(user *user.User, shell string) []string {
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
if runtime.GOOS == "windows" {
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
}
return []string{ return []string{
fmt.Sprint("SHELL=" + shell), fmt.Sprint("SHELL=" + shell),
fmt.Sprint("USER=" + user.Username), fmt.Sprint("USER=" + user.Username),
fmt.Sprint("LOGNAME=" + user.Username), fmt.Sprint("LOGNAME=" + user.Username),
fmt.Sprint("HOME=" + user.HomeDir), fmt.Sprint("HOME=" + user.HomeDir),
"PATH=" + pathValue, "PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games",
} }
} }

View File

@@ -165,7 +165,7 @@ func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, err
} }
if err := validateUsername(requestedUsername); err != nil { if err := validateUsername(requestedUsername); err != nil {
return nil, fmt.Errorf("invalid username %q: %w", requestedUsername, err) return nil, fmt.Errorf("invalid username: %w", err)
} }
u, err := lookupUser(requestedUsername) u, err := lookupUser(requestedUsername)

View File

@@ -152,18 +152,17 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
return groups, nil return groups, nil
} }
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping. // createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
// Returns the command and a cleanup function (no-op on Unix). func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
if err := validateUsername(localUser.Username); err != nil { if err := validateUsername(localUser.Username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) return nil, fmt.Errorf("invalid username: %w", err)
} }
uid, gid, groups, err := s.parseUserCredentials(localUser) uid, gid, groups, err := s.parseUserCredentials(localUser)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("parse user credentials: %w", err) return nil, fmt.Errorf("parse user credentials: %w", err)
} }
privilegeDropper := NewPrivilegeDropper() privilegeDropper := NewPrivilegeDropper()
config := ExecutorConfig{ config := ExecutorConfig{
@@ -176,8 +175,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User
PTY: hasPty, PTY: hasPty,
} }
cmd, err := privilegeDropper.CreateExecutorCommand(session.Context(), config) return privilegeDropper.CreateExecutorCommand(session.Context(), config)
return cmd, func() {}, err
} }
// enableUserSwitching is a no-op on Unix systems // enableUserSwitching is a no-op on Unix systems
@@ -188,9 +186,6 @@ func enableUserSwitching() error {
// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results // createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results
func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) { func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
localUser := privilegeResult.User localUser := privilegeResult.User
if localUser == nil {
return nil, errors.New("no user in privilege result")
}
if privilegeResult.UsedFallback { if privilegeResult.UsedFallback {
return s.createDirectPtyCommand(session, localUser, ptyReq), nil return s.createDirectPtyCommand(session, localUser, ptyReq), nil

View File

@@ -86,22 +86,20 @@ func validateUsernameFormat(username string) error {
return nil return nil
} }
// createExecutorCommand creates a command using Windows executor for privilege dropping. // createExecutorCommand creates a command using Windows executor for privilege dropping
// Returns the command and a cleanup function that must be called after starting the process. func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
username, _ := s.parseUsername(localUser.Username) username, _ := s.parseUsername(localUser.Username)
if err := validateUsername(username); err != nil { if err := validateUsername(username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err) return nil, fmt.Errorf("invalid username: %w", err)
} }
return s.createUserSwitchCommand(localUser, session, hasPty) return s.createUserSwitchCommand(localUser, session, hasPty)
} }
// createUserSwitchCommand creates a command with Windows user switching. // createUserSwitchCommand creates a command with Windows user switching
// Returns the command and a cleanup function that must be called after starting the process. func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, error) {
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
username, domain := s.parseUsername(localUser.Username) username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid) shell := getUserShell(localUser.Uid)
@@ -122,20 +120,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
} }
dropper := NewPrivilegeDropper() dropper := NewPrivilegeDropper()
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config) return dropper.CreateWindowsExecutorCommand(session.Context(), config)
if err != nil {
return nil, nil, err
}
cleanup := func() {
if token != 0 {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
log.Debugf("close primary token: %v", err)
}
}
}
return cmd, cleanup, nil
} }
// parseUsername extracts username and domain from a Windows username // parseUsername extracts username and domain from a Windows username

View File

@@ -141,7 +141,7 @@ func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error {
log.Debugf("close output write handle: %v", err) log.Debugf("close output write handle: %v", err)
} }
return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, session, pi.Process) return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, pi.Process)
} }
// createConPtyPipes creates input/output pipes for ConPty. // createConPtyPipes creates input/output pipes for ConPty.
@@ -323,13 +323,8 @@ func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) {
return primaryToken, nil return primaryToken, nil
} }
// SessionExiter provides the Exit method for reporting process exit status.
type SessionExiter interface {
Exit(code int) error
}
// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers. // bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, session SessionExiter, process windows.Handle) error { func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, process windows.Handle) error {
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return err return err
} }
@@ -342,15 +337,6 @@ func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Ha
return processErr return processErr
} }
var exitCode uint32
if err := windows.GetExitCodeProcess(process, &exitCode); err != nil {
log.Debugf("get exit code: %v", err)
} else {
if err := session.Exit(int(exitCode)); err != nil {
log.Debugf("report exit code: %v", err)
}
}
// Clean up in the original order after process completes // Clean up in the original order after process completes
if err := reader.Close(); err != nil { if err := reader.Close(); err != nil {
log.Debugf("close reader: %v", err) log.Debugf("close reader: %v", err)

View File

@@ -228,7 +228,6 @@ func TestWindowsHandleReader(t *testing.T) {
if err := windows.CloseHandle(writeHandle); err != nil { if err := windows.CloseHandle(writeHandle); err != nil {
t.Fatalf("Should close write handle: %v", err) t.Fatalf("Should close write handle: %v", err)
} }
writeHandle = windows.InvalidHandle
// Test reading // Test reading
reader := &windowsHandleReader{handle: readHandle} reader := &windowsHandleReader{handle: readHandle}

View File

@@ -11,12 +11,8 @@ import (
"strings" "strings"
"time" "time"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
probeRelay "github.com/netbirdio/netbird/client/internal/relay" probeRelay "github.com/netbirdio/netbird/client/internal/relay"
@@ -85,18 +81,6 @@ type NsServerGroupStateOutput struct {
Error string `json:"error" yaml:"error"` Error string `json:"error" yaml:"error"`
} }
type SSHSessionOutput struct {
Username string `json:"username" yaml:"username"`
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
Command string `json:"command" yaml:"command"`
JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
}
type SSHServerStateOutput struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
}
type OutputOverview struct { type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"` Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"` CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -116,10 +100,11 @@ type OutputOverview struct {
Events []SystemEventOutput `json:"events" yaml:"events"` Events []SystemEventOutput `json:"events" yaml:"events"`
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"` LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"` ProfileName string `json:"profileName" yaml:"profileName"`
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
} }
func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, daemonVersion string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview { func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
pbFullStatus := resp.GetFullStatus()
managementState := pbFullStatus.GetManagementState() managementState := pbFullStatus.GetManagementState()
managementOverview := ManagementStateOutput{ managementOverview := ManagementStateOutput{
URL: managementState.GetURL(), URL: managementState.GetURL(),
@@ -135,13 +120,12 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, da
} }
relayOverview := mapRelays(pbFullStatus.GetRelays()) relayOverview := mapRelays(pbFullStatus.GetRelays())
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState()) peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
peersOverview := mapPeers(pbFullStatus.GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
overview := OutputOverview{ overview := OutputOverview{
Peers: peersOverview, Peers: peersOverview,
CliVersion: version.NetbirdVersion(), CliVersion: version.NetbirdVersion(),
DaemonVersion: daemonVersion, DaemonVersion: resp.GetDaemonVersion(),
ManagementState: managementOverview, ManagementState: managementOverview,
SignalState: signalOverview, SignalState: signalOverview,
Relays: relayOverview, Relays: relayOverview,
@@ -157,7 +141,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, da
Events: mapEvents(pbFullStatus.GetEvents()), Events: mapEvents(pbFullStatus.GetEvents()),
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(), LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: profName, ProfileName: profName,
SSHServerState: sshServerOverview,
} }
if anon { if anon {
@@ -207,30 +190,6 @@ func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
return mappedNSGroups return mappedNSGroups
} }
func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
if sshServerState == nil {
return SSHServerStateOutput{
Enabled: false,
Sessions: []SSHSessionOutput{},
}
}
sessions := make([]SSHSessionOutput, 0, len(sshServerState.GetSessions()))
for _, session := range sshServerState.GetSessions() {
sessions = append(sessions, SSHSessionOutput{
Username: session.GetUsername(),
RemoteAddress: session.GetRemoteAddress(),
Command: session.GetCommand(),
JWTUsername: session.GetJwtUsername(),
})
}
return SSHServerStateOutput{
Enabled: sshServerState.GetEnabled(),
Sessions: sessions,
}
}
func mapPeers( func mapPeers(
peers []*proto.PeerState, peers []*proto.PeerState,
statusFilter string, statusFilter string,
@@ -341,7 +300,7 @@ func ParseToYAML(overview OutputOverview) (string, error) {
return string(yamlBytes), nil return string(yamlBytes), nil
} }
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string { func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
var managementConnString string var managementConnString string
if overview.ManagementState.Connected { if overview.ManagementState.Connected {
managementConnString = "Connected" managementConnString = "Connected"
@@ -446,41 +405,6 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
lazyConnectionEnabledStatus = "true" lazyConnectionEnabledStatus = "true"
} }
sshServerStatus := "Disabled"
if overview.SSHServerState.Enabled {
sessionCount := len(overview.SSHServerState.Sessions)
if sessionCount > 0 {
sessionWord := "session"
if sessionCount > 1 {
sessionWord = "sessions"
}
sshServerStatus = fmt.Sprintf("Enabled (%d active %s)", sessionCount, sessionWord)
} else {
sshServerStatus = "Enabled"
}
if showSSHSessions && sessionCount > 0 {
for _, session := range overview.SSHServerState.Sessions {
var sessionDisplay string
if session.JWTUsername != "" {
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
session.JWTUsername,
session.RemoteAddress,
session.Username,
session.Command,
)
} else {
sessionDisplay = fmt.Sprintf("[%s@%s] %s",
session.Username,
session.RemoteAddress,
session.Command,
)
}
sshServerStatus += "\n " + sessionDisplay
}
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS goos := runtime.GOOS
@@ -504,7 +428,6 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
"Interface type: %s\n"+ "Interface type: %s\n"+
"Quantum resistance: %s\n"+ "Quantum resistance: %s\n"+
"Lazy connection: %s\n"+ "Lazy connection: %s\n"+
"SSH Server: %s\n"+
"Networks: %s\n"+ "Networks: %s\n"+
"Forwarding rules: %d\n"+ "Forwarding rules: %d\n"+
"Peers count: %s\n", "Peers count: %s\n",
@@ -521,7 +444,6 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
interfaceTypeString, interfaceTypeString,
rosenpassEnabledStatus, rosenpassEnabledStatus,
lazyConnectionEnabledStatus, lazyConnectionEnabledStatus,
sshServerStatus,
networks, networks,
overview.NumberOfForwardingRules, overview.NumberOfForwardingRules,
peersCountString, peersCountString,
@@ -532,7 +454,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
func ParseToFullDetailSummary(overview OutputOverview) string { func ParseToFullDetailSummary(overview OutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive) parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
parsedEventsString := parseEvents(overview.Events) parsedEventsString := parseEvents(overview.Events)
summary := ParseGeneralSummary(overview, true, true, true, true) summary := ParseGeneralSummary(overview, true, true, true)
return fmt.Sprintf( return fmt.Sprintf(
"Peers detail:"+ "Peers detail:"+
@@ -546,94 +468,6 @@ func ParseToFullDetailSummary(overview OutputOverview) string {
) )
} }
func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string { func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
var ( var (
peersString = "" peersString = ""
@@ -912,13 +746,4 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
event.Metadata[k] = a.AnonymizeString(v) event.Metadata[k] = a.AnonymizeString(v)
} }
} }
for i, session := range overview.SSHServerState.Sessions {
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
} else {
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
}
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
}
} }

View File

@@ -231,14 +231,10 @@ var overview = OutputOverview{
Networks: []string{ Networks: []string{
"10.10.0.0/24", "10.10.0.0/24",
}, },
SSHServerState: SSHServerStateOutput{
Enabled: false,
Sessions: []SSHSessionOutput{},
},
} }
func TestConversionFromFullStatusToOutputOverview(t *testing.T) { func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), false, resp.GetDaemonVersion(), "", nil, nil, nil, "", "") convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "", "")
assert.Equal(t, overview, convertedResult) assert.Equal(t, overview, convertedResult)
} }
@@ -389,11 +385,7 @@ func TestParsingToJSON(t *testing.T) {
], ],
"events": [], "events": [],
"lazyConnectionEnabled": false, "lazyConnectionEnabled": false,
"profileName":"", "profileName":""
"sshServer":{
"enabled":false,
"sessions":[]
}
}` }`
// @formatter:on // @formatter:on
@@ -496,9 +488,6 @@ dnsServers:
events: [] events: []
lazyConnectionEnabled: false lazyConnectionEnabled: false
profileName: "" profileName: ""
sshServer:
enabled: false
sessions: []
` `
assert.Equal(t, expectedYAML, yaml) assert.Equal(t, expectedYAML, yaml)
@@ -565,7 +554,6 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel Interface type: Kernel
Quantum resistance: false Quantum resistance: false
Lazy connection: false Lazy connection: false
SSH Server: Disabled
Networks: 10.10.0.0/24 Networks: 10.10.0.0/24
Forwarding rules: 0 Forwarding rules: 0
Peers count: 2/2 Connected Peers count: 2/2 Connected
@@ -575,7 +563,7 @@ Peers count: 2/2 Connected
} }
func TestParsingToShortVersion(t *testing.T) { func TestParsingToShortVersion(t *testing.T) {
shortVersion := ParseGeneralSummary(overview, false, false, false, false) shortVersion := ParseGeneralSummary(overview, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + ` expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1 Daemon version: 0.14.1
@@ -590,7 +578,6 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel Interface type: Kernel
Quantum resistance: false Quantum resistance: false
Lazy connection: false Lazy connection: false
SSH Server: Disabled
Networks: 10.10.0.0/24 Networks: 10.10.0.0/24
Forwarding rules: 0 Forwarding rules: 0
Peers count: 2/2 Connected Peers count: 2/2 Connected

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.4 KiB

View File

@@ -55,7 +55,6 @@ const (
const ( const (
censoredPreSharedKey = "**********" censoredPreSharedKey = "**********"
maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds
) )
func main() { func main() {
@@ -86,22 +85,21 @@ func main() {
// Create the service client (this also builds the settings or networks UI if requested). // Create the service client (this also builds the settings or networks UI if requested).
client := newServiceClient(&newServiceClientArgs{ client := newServiceClient(&newServiceClientArgs{
addr: flags.daemonAddr, addr: flags.daemonAddr,
logFile: logFile, logFile: logFile,
app: a, app: a,
showSettings: flags.showSettings, showSettings: flags.showSettings,
showNetworks: flags.showNetworks, showNetworks: flags.showNetworks,
showLoginURL: flags.showLoginURL, showLoginURL: flags.showLoginURL,
showDebug: flags.showDebug, showDebug: flags.showDebug,
showProfiles: flags.showProfiles, showProfiles: flags.showProfiles,
showQuickActions: flags.showQuickActions,
}) })
// Watch for theme/settings changes to update the icon. // Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client) go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set. // Run in window mode if any UI flag was set.
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions { if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles {
a.Run() a.Run()
return return
} }
@@ -113,29 +111,23 @@ func main() {
return return
} }
if running { if running {
log.Infof("another process is running with pid %d, sending signal to show window", pid) log.Warnf("another process is running with pid %d, exiting", pid)
if err := sendShowWindowSignal(pid); err != nil {
log.Errorf("send signal to running instance: %v", err)
}
return return
} }
client.setupSignalHandler(client.ctx)
client.setDefaultFonts() client.setDefaultFonts()
systray.Run(client.onTrayReady, client.onTrayExit) systray.Run(client.onTrayReady, client.onTrayExit)
} }
type cliFlags struct { type cliFlags struct {
daemonAddr string daemonAddr string
showSettings bool showSettings bool
showNetworks bool showNetworks bool
showProfiles bool showProfiles bool
showDebug bool showDebug bool
showLoginURL bool showLoginURL bool
showQuickActions bool errorMsg string
errorMsg string saveLogsInFile bool
saveLogsInFile bool
} }
// parseFlags reads and returns all needed command-line flags. // parseFlags reads and returns all needed command-line flags.
@@ -151,7 +143,6 @@ func parseFlags() *cliFlags {
flag.BoolVar(&flags.showNetworks, "networks", false, "run networks window") flag.BoolVar(&flags.showNetworks, "networks", false, "run networks window")
flag.BoolVar(&flags.showProfiles, "profiles", false, "run profiles window") flag.BoolVar(&flags.showProfiles, "profiles", false, "run profiles window")
flag.BoolVar(&flags.showDebug, "debug", false, "run debug window") flag.BoolVar(&flags.showDebug, "debug", false, "run debug window")
flag.BoolVar(&flags.showQuickActions, "quick-actions", false, "run quick actions window")
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window") flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window") flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
@@ -167,9 +158,11 @@ func initLogFile() (string, error) {
// watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon. // watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon.
func watchSettingsChanges(a fyne.App, client *serviceClient) { func watchSettingsChanges(a fyne.App, client *serviceClient) {
a.Settings().AddListener(func(settings fyne.Settings) { settingsChangeChan := make(chan fyne.Settings)
a.Settings().AddChangeListener(settingsChangeChan)
for range settingsChangeChan {
client.updateIcon() client.updateIcon()
}) }
} }
// showErrorMessage displays an error message in a simple window. // showErrorMessage displays an error message in a simple window.
@@ -277,7 +270,6 @@ type serviceClient struct {
sEnableSSHLocalPortForward *widget.Check sEnableSSHLocalPortForward *widget.Check
sEnableSSHRemotePortForward *widget.Check sEnableSSHRemotePortForward *widget.Check
sDisableSSHAuth *widget.Check sDisableSSHAuth *widget.Check
iSSHJWTCacheTTL *widget.Entry
// observable settings over corresponding iMngURL and iPreSharedKey values. // observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string managementURL string
@@ -297,7 +289,6 @@ type serviceClient struct {
enableSSHLocalPortForward bool enableSSHLocalPortForward bool
enableSSHRemotePortForward bool enableSSHRemotePortForward bool
disableSSHAuth bool disableSSHAuth bool
sshJWTCacheTTL int
connected bool connected bool
update *version.Update update *version.Update
@@ -307,7 +298,6 @@ type serviceClient struct {
showNetworks bool showNetworks bool
wNetworks fyne.Window wNetworks fyne.Window
wProfiles fyne.Window wProfiles fyne.Window
wQuickActions fyne.Window
eventManager *event.Manager eventManager *event.Manager
@@ -327,15 +317,14 @@ type menuHandler struct {
} }
type newServiceClientArgs struct { type newServiceClientArgs struct {
addr string addr string
logFile string logFile string
app fyne.App app fyne.App
showSettings bool showSettings bool
showNetworks bool showNetworks bool
showDebug bool showDebug bool
showLoginURL bool showLoginURL bool
showProfiles bool showProfiles bool
showQuickActions bool
} }
// newServiceClient instance constructor // newServiceClient instance constructor
@@ -371,8 +360,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
s.showDebugUI() s.showDebugUI()
case args.showProfiles: case args.showProfiles:
s.showProfilesUI() s.showProfilesUI()
case args.showQuickActions:
s.showQuickActionsUI()
} }
return s return s
@@ -454,7 +441,6 @@ func (s *serviceClient) showSettingsUI() {
s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil) s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil)
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil) s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil) s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
s.iSSHJWTCacheTTL = widget.NewEntry()
s.wSettings.SetContent(s.getSettingsForm()) s.wSettings.SetContent(s.getSettingsForm())
s.wSettings.Resize(fyne.NewSize(600, 400)) s.wSettings.Resize(fyne.NewSize(600, 400))
@@ -510,15 +496,11 @@ func (s *serviceClient) saveSettings() {
} }
iMngURL := strings.TrimSpace(s.iMngURL.Text) iMngURL := strings.TrimSpace(s.iMngURL.Text)
defer s.wSettings.Close()
if s.hasSettingsChanged(iMngURL, port, mtu) { if s.hasSettingsChanged(iMngURL, port, mtu) {
if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil { s.applySettingsChanges(iMngURL, port, mtu)
dialog.ShowError(err, s.wSettings)
return
}
} }
s.wSettings.Close()
} }
func (s *serviceClient) validateSettings() error { func (s *serviceClient) validateSettings() error {
@@ -535,9 +517,6 @@ func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
if err != nil { if err != nil {
return 0, 0, errors.New("Invalid interface port") return 0, 0, errors.New("Invalid interface port")
} }
if port < 1 || port > 65535 {
return 0, 0, errors.New("Invalid interface port: out of range 1-65535")
}
var mtu int64 var mtu int64
mtuText := strings.TrimSpace(s.iMTU.Text) mtuText := strings.TrimSpace(s.iMTU.Text)
@@ -569,21 +548,20 @@ func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool
s.hasSSHChanges() s.hasSSHChanges()
} }
func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error { func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) {
s.managementURL = iMngURL s.managementURL = iMngURL
s.preSharedKey = s.iPreSharedKey.Text s.preSharedKey = s.iPreSharedKey.Text
s.mtu = uint16(mtu) s.mtu = uint16(mtu)
req, err := s.buildSetConfigRequest(iMngURL, port, mtu) req, err := s.buildSetConfigRequest(iMngURL, port, mtu)
if err != nil { if err != nil {
return fmt.Errorf("build config request: %w", err) log.Errorf("build config request: %v", err)
return
} }
if err := s.sendConfigUpdate(req); err != nil { if err := s.sendConfigUpdate(req); err != nil {
return fmt.Errorf("set configuration: %w", err) dialog.ShowError(fmt.Errorf("Failed to set configuration: %v", err), s.wSettings)
} }
return nil
} }
func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (*proto.SetConfigRequest, error) { func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (*proto.SetConfigRequest, error) {
@@ -621,23 +599,10 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
req.EnableSSHRoot = &s.sEnableSSHRoot.Checked req.EnableSSHRoot = &s.sEnableSSHRoot.Checked
req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked
req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked req.EnableSSHLocalPortForward = &s.sEnableSSHLocalPortForward.Checked
req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked req.EnableSSHRemotePortForward = &s.sEnableSSHRemotePortForward.Checked
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text)
if sshJWTCacheTTLText != "" {
sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32)
if err != nil {
return nil, errors.New("Invalid SSH JWT Cache TTL value")
}
if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL {
return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL)
}
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if s.iPreSharedKey.Text != censoredPreSharedKey { if s.iPreSharedKey.Text != censoredPreSharedKey {
req.OptionalPreSharedKey = &s.iPreSharedKey.Text req.OptionalPreSharedKey = &s.iPreSharedKey.Text
} }
@@ -723,27 +688,16 @@ func (s *serviceClient) getSSHForm() *widget.Form {
{Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward}, {Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward},
{Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward}, {Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward},
{Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth}, {Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth},
{Text: "JWT Cache TTL (seconds, 0=disabled)", Widget: s.iSSHJWTCacheTTL},
}, },
} }
} }
func (s *serviceClient) hasSSHChanges() bool { func (s *serviceClient) hasSSHChanges() bool {
currentSSHJWTCacheTTL := s.sshJWTCacheTTL
if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" {
val, err := strconv.Atoi(text)
if err != nil {
return true
}
currentSSHJWTCacheTTL = val
}
return s.enableSSHRoot != s.sEnableSSHRoot.Checked || return s.enableSSHRoot != s.sEnableSSHRoot.Checked ||
s.enableSSHSFTP != s.sEnableSSHSFTP.Checked || s.enableSSHSFTP != s.sEnableSSHSFTP.Checked ||
s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked || s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked ||
s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked || s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked ||
s.disableSSHAuth != s.sDisableSSHAuth.Checked || s.disableSSHAuth != s.sDisableSSHAuth.Checked
s.sshJWTCacheTTL != currentSSHJWTCacheTTL
} }
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) { func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
@@ -762,20 +716,11 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
return nil, fmt.Errorf("get current user: %w", err) return nil, fmt.Errorf("get current user: %w", err)
} }
loginReq := &proto.LoginRequest{ loginResp, err := conn.Login(ctx, &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name, ProfileName: &activeProf.Name,
Username: &currUser.Username, Username: &currUser.Username,
} })
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginReq.Hint = &profileState.Email
}
loginResp, err := conn.Login(ctx, loginReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("login to management: %w", err) return nil, fmt.Errorf("login to management: %w", err)
} }
@@ -1280,9 +1225,6 @@ func (s *serviceClient) getSrvConfig() {
if cfg.DisableSSHAuth != nil { if cfg.DisableSSHAuth != nil {
s.disableSSHAuth = *cfg.DisableSSHAuth s.disableSSHAuth = *cfg.DisableSSHAuth
} }
if cfg.SSHJWTCacheTTL != nil {
s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL
}
if s.showAdvancedSettings { if s.showAdvancedSettings {
s.iMngURL.SetText(s.managementURL) s.iMngURL.SetText(s.managementURL)
@@ -1319,9 +1261,6 @@ func (s *serviceClient) getSrvConfig() {
if cfg.DisableSSHAuth != nil { if cfg.DisableSSHAuth != nil {
s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth) s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth)
} }
if cfg.SSHJWTCacheTTL != nil {
s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL))
}
} }
if s.mNotifications == nil { if s.mNotifications == nil {
@@ -1392,14 +1331,21 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
config.DisableServerRoutes = cfg.DisableServerRoutes config.DisableServerRoutes = cfg.DisableServerRoutes
config.BlockLANAccess = cfg.BlockLanAccess config.BlockLANAccess = cfg.BlockLanAccess
config.EnableSSHRoot = &cfg.EnableSSHRoot if cfg.EnableSSHRoot {
config.EnableSSHSFTP = &cfg.EnableSSHSFTP config.EnableSSHRoot = &cfg.EnableSSHRoot
config.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding }
config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding if cfg.EnableSSHSFTP {
config.DisableSSHAuth = &cfg.DisableSSHAuth config.EnableSSHSFTP = &cfg.EnableSSHSFTP
}
ttl := int(cfg.SshJWTCacheTTL) if cfg.EnableSSHLocalPortForwarding {
config.SSHJWTCacheTTL = &ttl config.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding
}
if cfg.EnableSSHRemotePortForwarding {
config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
}
if cfg.DisableSSHAuth {
config.DisableSSHAuth = &cfg.DisableSSHAuth
}
return &config return &config
} }

View File

@@ -18,7 +18,9 @@ import (
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types" uptypes "github.com/netbirdio/netbird/upload-server/types"
) )
@@ -289,18 +291,19 @@ func (s *serviceClient) handleRunForDuration(
return return
} }
defer s.restoreServiceState(conn, initialState) statusOutput, err := s.collectDebugData(conn, initialState, params, progressUI)
if err != nil {
if err := s.collectDebugData(conn, initialState, params, progressUI); err != nil {
handleError(progressUI, err.Error()) handleError(progressUI, err.Error())
return return
} }
if err := s.createDebugBundleFromCollection(conn, params, progressUI); err != nil { if err := s.createDebugBundleFromCollection(conn, params, statusOutput, progressUI); err != nil {
handleError(progressUI, err.Error()) handleError(progressUI, err.Error())
return return
} }
s.restoreServiceState(conn, initialState)
progressUI.statusLabel.SetText("Bundle created successfully") progressUI.statusLabel.SetText("Bundle created successfully")
} }
@@ -414,33 +417,68 @@ func (s *serviceClient) collectDebugData(
state *debugInitialState, state *debugInitialState,
params *debugCollectionParams, params *debugCollectionParams,
progress *progressUI, progress *progressUI,
) error { ) (string, error) {
ctx, cancel := context.WithTimeout(s.ctx, params.duration) ctx, cancel := context.WithTimeout(s.ctx, params.duration)
defer cancel() defer cancel()
var wg sync.WaitGroup var wg sync.WaitGroup
startProgressTracker(ctx, &wg, params.duration, progress) startProgressTracker(ctx, &wg, params.duration, progress)
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil { if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
return err return "", err
} }
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get post-up status: %v", err)
}
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
wg.Wait() wg.Wait()
progress.progressBar.Hide() progress.progressBar.Hide()
progress.statusLabel.SetText("Collecting debug data...") progress.statusLabel.SetText("Collecting debug data...")
return nil preDownStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get pre-down status: %v", err)
}
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
time.Now().Format(time.RFC3339), params.duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, preDownStatusOutput)
return statusOutput, nil
} }
// Create the debug bundle with collected data // Create the debug bundle with collected data
func (s *serviceClient) createDebugBundleFromCollection( func (s *serviceClient) createDebugBundleFromCollection(
conn proto.DaemonServiceClient, conn proto.DaemonServiceClient,
params *debugCollectionParams, params *debugCollectionParams,
statusOutput string,
progress *progressUI, progress *progressUI,
) error { ) error {
progress.statusLabel.SetText("Creating debug bundle with collected logs...") progress.statusLabel.SetText("Creating debug bundle with collected logs...")
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: params.anonymize, Anonymize: params.anonymize,
Status: statusOutput,
SystemInfo: params.systemInfo, SystemInfo: params.systemInfo,
} }
@@ -462,7 +500,7 @@ func (s *serviceClient) createDebugBundleFromCollection(
if uploadFailureReason != "" { if uploadFailureReason != "" {
showUploadFailedDialog(progress.window, localPath, uploadFailureReason) showUploadFailedDialog(progress.window, localPath, uploadFailureReason)
} else { } else {
showUploadSuccessDialog(s.app, progress.window, localPath, uploadedKey) showUploadSuccessDialog(progress.window, localPath, uploadedKey)
} }
} else { } else {
showBundleCreatedDialog(progress.window, localPath) showBundleCreatedDialog(progress.window, localPath)
@@ -527,7 +565,7 @@ func (s *serviceClient) handleDebugCreation(
if uploadFailureReason != "" { if uploadFailureReason != "" {
showUploadFailedDialog(w, localPath, uploadFailureReason) showUploadFailedDialog(w, localPath, uploadFailureReason)
} else { } else {
showUploadSuccessDialog(s.app, w, localPath, uploadedKey) showUploadSuccessDialog(w, localPath, uploadedKey)
} }
} else { } else {
showBundleCreatedDialog(w, localPath) showBundleCreatedDialog(w, localPath)
@@ -543,8 +581,26 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
return nil, fmt.Errorf("get client: %v", err) return nil, fmt.Errorf("get client: %v", err)
} }
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("failed to get status for debug bundle: %v", err)
}
var statusOutput string
if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymize, Anonymize: anonymize,
Status: statusOutput,
SystemInfo: systemInfo, SystemInfo: systemInfo,
} }
@@ -609,7 +665,7 @@ func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) {
} }
// showUploadSuccessDialog displays a dialog when upload succeeds // showUploadSuccessDialog displays a dialog when upload succeeds
func showUploadSuccessDialog(a fyne.App, w fyne.Window, localPath, uploadedKey string) { func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
log.Infof("Upload key: %s", uploadedKey) log.Infof("Upload key: %s", uploadedKey)
keyEntry := widget.NewEntry() keyEntry := widget.NewEntry()
keyEntry.SetText(uploadedKey) keyEntry.SetText(uploadedKey)
@@ -627,7 +683,7 @@ func showUploadSuccessDialog(a fyne.App, w fyne.Window, localPath, uploadedKey s
customDialog := dialog.NewCustom("Upload Successful", "OK", content, w) customDialog := dialog.NewCustom("Upload Successful", "OK", content, w)
copyBtn := createButtonWithAction("Copy key", func() { copyBtn := createButtonWithAction("Copy key", func() {
a.Clipboard().SetContent(uploadedKey) w.Clipboard().SetContent(uploadedKey)
log.Info("Upload key copied to clipboard") log.Info("Upload key copied to clipboard")
}) })

View File

@@ -9,9 +9,6 @@ import (
//go:embed assets/netbird.png //go:embed assets/netbird.png
var iconAbout []byte var iconAbout []byte
//go:embed assets/netbird-disconnected.png
var iconAboutDisconnected []byte
//go:embed assets/netbird-systemtray-connected.png //go:embed assets/netbird-systemtray-connected.png
var iconConnected []byte var iconConnected []byte

View File

@@ -7,9 +7,6 @@ import (
//go:embed assets/netbird.ico //go:embed assets/netbird.ico
var iconAbout []byte var iconAbout []byte
//go:embed assets/netbird-disconnected.ico
var iconAboutDisconnected []byte
//go:embed assets/netbird-systemtray-connected.ico //go:embed assets/netbird-systemtray-connected.ico
var iconConnected []byte var iconConnected []byte

View File

@@ -1,349 +0,0 @@
//go:build !(linux && 386)
//go:generate fyne bundle -o quickactions_assets.go assets/connected.png
//go:generate fyne bundle -o quickactions_assets.go -append assets/disconnected.png
package main
import (
"context"
_ "embed"
"fmt"
"runtime"
"sync/atomic"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/canvas"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/layout"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
type quickActionsUiState struct {
connectionStatus string
isToggleButtonEnabled bool
isConnectionChanged bool
toggleAction func()
}
func newQuickActionsUiState() quickActionsUiState {
return quickActionsUiState{
connectionStatus: string(internal.StatusIdle),
isToggleButtonEnabled: false,
isConnectionChanged: false,
}
}
type clientConnectionStatusProvider interface {
connectionStatus(ctx context.Context) (string, error)
}
type daemonClientConnectionStatusProvider struct {
client proto.DaemonServiceClient
}
func (d daemonClientConnectionStatusProvider) connectionStatus(ctx context.Context) (string, error) {
childCtx, cancel := context.WithTimeout(ctx, 400*time.Millisecond)
defer cancel()
status, err := d.client.Status(childCtx, &proto.StatusRequest{})
if err != nil {
return "", err
}
return status.Status, nil
}
type clientCommand interface {
execute() error
}
type connectCommand struct {
connectClient func() error
}
func (c connectCommand) execute() error {
return c.connectClient()
}
type disconnectCommand struct {
disconnectClient func() error
}
func (c disconnectCommand) execute() error {
return c.disconnectClient()
}
type quickActionsViewModel struct {
provider clientConnectionStatusProvider
connect clientCommand
disconnect clientCommand
uiChan chan quickActionsUiState
isWatchingConnectionStatus atomic.Bool
}
func newQuickActionsViewModel(ctx context.Context, provider clientConnectionStatusProvider, connect, disconnect clientCommand, uiChan chan quickActionsUiState) {
viewModel := quickActionsViewModel{
provider: provider,
connect: connect,
disconnect: disconnect,
uiChan: uiChan,
}
viewModel.isWatchingConnectionStatus.Store(true)
// base UI status
uiChan <- newQuickActionsUiState()
// this retrieves the current connection status
// and pushes the UI state that reflects it via uiChan
go viewModel.watchConnectionStatus(ctx)
}
func (q *quickActionsViewModel) updateUiState(ctx context.Context) {
uiState := newQuickActionsUiState()
connectionStatus, err := q.provider.connectionStatus(ctx)
if err != nil {
log.Errorf("Status: Error - %v", err)
q.uiChan <- uiState
return
}
if connectionStatus == string(internal.StatusConnected) {
uiState.toggleAction = func() {
q.executeCommand(q.disconnect)
}
} else {
uiState.toggleAction = func() {
q.executeCommand(q.connect)
}
}
uiState.isToggleButtonEnabled = true
uiState.connectionStatus = connectionStatus
q.uiChan <- uiState
}
func (q *quickActionsViewModel) watchConnectionStatus(ctx context.Context) {
ticker := time.NewTicker(1000 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if q.isWatchingConnectionStatus.Load() {
q.updateUiState(ctx)
}
}
}
}
func (q *quickActionsViewModel) executeCommand(command clientCommand) {
uiState := newQuickActionsUiState()
// newQuickActionsUiState starts with Idle connection status,
// and all that's necessary here is to just disable the toggle button.
uiState.connectionStatus = ""
q.uiChan <- uiState
q.isWatchingConnectionStatus.Store(false)
err := command.execute()
if err != nil {
log.Errorf("Status: Error - %v", err)
q.isWatchingConnectionStatus.Store(true)
} else {
uiState = newQuickActionsUiState()
uiState.isConnectionChanged = true
q.uiChan <- uiState
}
}
func getSystemTrayName() string {
os := runtime.GOOS
switch os {
case "darwin":
return "menu bar"
default:
return "system tray"
}
}
func (s *serviceClient) getNetBirdImage(name string, content []byte) *canvas.Image {
imageSize := fyne.NewSize(64, 64)
resource := fyne.NewStaticResource(name, content)
image := canvas.NewImageFromResource(resource)
image.FillMode = canvas.ImageFillContain
image.SetMinSize(imageSize)
image.Resize(imageSize)
return image
}
type quickActionsUiComponents struct {
content *fyne.Container
toggleConnectionButton *widget.Button
connectedLabelText, disconnectedLabelText string
connectedImage, disconnectedImage *canvas.Image
connectedCircleRes, disconnectedCircleRes fyne.Resource
}
// applyQuickActionsUiState applies a single UI state to the quick actions window.
// It closes the window and returns true if the connection status has changed,
// in which case the caller should stop processing further states.
func (s *serviceClient) applyQuickActionsUiState(
uiState quickActionsUiState,
components quickActionsUiComponents,
) bool {
if uiState.isConnectionChanged {
fyne.DoAndWait(func() {
s.wQuickActions.Close()
})
return true
}
var logo *canvas.Image
var buttonText string
var buttonIcon fyne.Resource
if uiState.connectionStatus == string(internal.StatusConnected) {
buttonText = components.connectedLabelText
buttonIcon = components.connectedCircleRes
logo = components.connectedImage
} else if uiState.connectionStatus == string(internal.StatusIdle) {
buttonText = components.disconnectedLabelText
buttonIcon = components.disconnectedCircleRes
logo = components.disconnectedImage
}
fyne.DoAndWait(func() {
if buttonText != "" {
components.toggleConnectionButton.SetText(buttonText)
}
if buttonIcon != nil {
components.toggleConnectionButton.SetIcon(buttonIcon)
}
if uiState.isToggleButtonEnabled {
components.toggleConnectionButton.Enable()
} else {
components.toggleConnectionButton.Disable()
}
components.toggleConnectionButton.OnTapped = func() {
if uiState.toggleAction != nil {
go uiState.toggleAction()
}
}
components.toggleConnectionButton.Refresh()
// the second position in the content's object array is the NetBird logo.
if logo != nil {
components.content.Objects[1] = logo
components.content.Refresh()
}
})
return false
}
// showQuickActionsUI displays a simple window with the NetBird logo and a connection toggle button.
func (s *serviceClient) showQuickActionsUI() {
s.wQuickActions = s.app.NewWindow("NetBird")
vmCtx, vmCancel := context.WithCancel(s.ctx)
s.wQuickActions.SetOnClosed(vmCancel)
client, err := s.getSrvClient(defaultFailTimeout)
connCmd := connectCommand{
connectClient: func() error {
return s.menuUpClick(s.ctx)
},
}
disConnCmd := disconnectCommand{
disconnectClient: func() error {
return s.menuDownClick()
},
}
if err != nil {
log.Errorf("get service client: %v", err)
return
}
uiChan := make(chan quickActionsUiState, 1)
newQuickActionsViewModel(vmCtx, daemonClientConnectionStatusProvider{client: client}, connCmd, disConnCmd, uiChan)
connectedImage := s.getNetBirdImage("netbird.png", iconAbout)
disconnectedImage := s.getNetBirdImage("netbird-disconnected.png", iconAboutDisconnected)
connectedCircle := canvas.NewImageFromResource(resourceConnectedPng)
disconnectedCircle := canvas.NewImageFromResource(resourceDisconnectedPng)
connectedLabelText := "Disconnect"
disconnectedLabelText := "Connect"
toggleConnectionButton := widget.NewButtonWithIcon(disconnectedLabelText, disconnectedCircle.Resource, func() {
// This button's tap function will be set when an ui state arrives via the uiChan channel.
})
// Button starts disabled until the first ui state arrives.
toggleConnectionButton.Disable()
hintLabelText := fmt.Sprintf("You can always access NetBird from your %s.", getSystemTrayName())
hintLabel := widget.NewLabel(hintLabelText)
content := container.NewVBox(
layout.NewSpacer(),
disconnectedImage,
layout.NewSpacer(),
container.NewCenter(toggleConnectionButton),
layout.NewSpacer(),
container.NewCenter(hintLabel),
)
// this watches for ui state updates.
go func() {
for {
select {
case <-vmCtx.Done():
return
case uiState, ok := <-uiChan:
if !ok {
return
}
closed := s.applyQuickActionsUiState(
uiState,
quickActionsUiComponents{
content,
toggleConnectionButton,
connectedLabelText, disconnectedLabelText,
connectedImage, disconnectedImage,
connectedCircle.Resource, disconnectedCircle.Resource,
},
)
if closed {
return
}
}
}
}()
s.wQuickActions.SetContent(content)
s.wQuickActions.Resize(fyne.NewSize(400, 200))
s.wQuickActions.SetFixedSize(true)
s.wQuickActions.Show()
}

View File

@@ -1,23 +0,0 @@
// auto-generated
// Code generated by '$ fyne bundle'. DO NOT EDIT.
package main
import (
_ "embed"
"fyne.io/fyne/v2"
)
//go:embed assets/connected.png
var resourceConnectedPngData []byte
var resourceConnectedPng = &fyne.StaticResource{
StaticName: "assets/connected.png",
StaticContent: resourceConnectedPngData,
}
//go:embed assets/disconnected.png
var resourceDisconnectedPngData []byte
var resourceDisconnectedPng = &fyne.StaticResource{
StaticName: "assets/disconnected.png",
StaticContent: resourceDisconnectedPngData,
}

View File

@@ -1,76 +0,0 @@
//go:build !windows && !(linux && 386)
package main
import (
"context"
"os"
"os/exec"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
)
// setupSignalHandler sets up a signal handler to listen for SIGUSR1.
// When received, it opens the quick actions window.
func (s *serviceClient) setupSignalHandler(ctx context.Context) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGUSR1)
go func() {
for {
select {
case <-ctx.Done():
return
case <-sigChan:
log.Info("received SIGUSR1 signal, opening quick actions window")
s.openQuickActions()
}
}
}()
}
// openQuickActions opens the quick actions window by spawning a new process.
func (s *serviceClient) openQuickActions() {
proc, err := os.Executable()
if err != nil {
log.Errorf("get executable path: %v", err)
return
}
cmd := exec.CommandContext(s.ctx, proc,
"--quick-actions=true",
"--daemon-addr="+s.addr,
)
if out := s.attachOutput(cmd); out != nil {
defer func() {
if err := out.Close(); err != nil {
log.Errorf("close log file %s: %v", s.logFile, err)
}
}()
}
log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr)
if err := cmd.Start(); err != nil {
log.Errorf("start quick actions window: %v", err)
return
}
go func() {
if err := cmd.Wait(); err != nil {
log.Debugf("quick actions window exited: %v", err)
}
}()
}
// sendShowWindowSignal sends SIGUSR1 to the specified PID.
func sendShowWindowSignal(pid int32) error {
process, err := os.FindProcess(int(pid))
if err != nil {
return err
}
return process.Signal(syscall.SIGUSR1)
}

View File

@@ -1,171 +0,0 @@
//go:build windows
package main
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
quickActionsTriggerEventName = `Global\NetBirdQuickActionsTriggerEvent`
waitTimeout = 5 * time.Second
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
desiredAccesses = windows.SYNCHRONIZE | windows.EVENT_MODIFY_STATE
)
func getEventNameUint16Pointer() (*uint16, error) {
eventNamePtr, err := windows.UTF16PtrFromString(quickActionsTriggerEventName)
if err != nil {
log.Errorf("Failed to convert event name '%s' to UTF16: %v", quickActionsTriggerEventName, err)
return nil, err
}
return eventNamePtr, nil
}
// setupSignalHandler sets up signal handling for Windows.
// Windows doesn't support SIGUSR1, so this uses a similar approach using windows.Events.
func (s *serviceClient) setupSignalHandler(ctx context.Context) {
eventNamePtr, err := getEventNameUint16Pointer()
if err != nil {
return
}
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
if err != nil {
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
log.Warnf("Quick actions trigger event '%s' already exists. Attempting to open.", quickActionsTriggerEventName)
eventHandle, err = windows.OpenEvent(desiredAccesses, false, eventNamePtr)
if err != nil {
log.Errorf("Failed to open existing quick actions trigger event '%s': %v", quickActionsTriggerEventName, err)
return
}
log.Infof("Successfully opened existing quick actions trigger event '%s'.", quickActionsTriggerEventName)
} else {
log.Errorf("Failed to create quick actions trigger event '%s': %v", quickActionsTriggerEventName, err)
return
}
}
if eventHandle == windows.InvalidHandle {
log.Errorf("Obtained an invalid handle for quick actions trigger event '%s'", quickActionsTriggerEventName)
return
}
log.Infof("Quick actions handler waiting for signal on event: %s", quickActionsTriggerEventName)
go s.waitForEvent(ctx, eventHandle)
}
func (s *serviceClient) waitForEvent(ctx context.Context, eventHandle windows.Handle) {
defer func() {
if err := windows.CloseHandle(eventHandle); err != nil {
log.Errorf("Failed to close quick actions event handle '%s': %v", quickActionsTriggerEventName, err)
}
}()
for {
if ctx.Err() != nil {
return
}
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
switch status {
case windows.WAIT_OBJECT_0:
log.Info("Received signal on quick actions event. Opening quick actions window.")
// reset the event so it can be triggered again later (manual reset == 1)
if err := windows.ResetEvent(eventHandle); err != nil {
log.Errorf("Failed to reset quick actions event '%s': %v", quickActionsTriggerEventName, err)
}
s.openQuickActions()
case uint32(windows.WAIT_TIMEOUT):
default:
if isDone := logUnexpectedStatus(ctx, status, err); isDone {
return
}
}
}
}
func logUnexpectedStatus(ctx context.Context, status uint32, err error) bool {
log.Errorf("Unexpected status %d from WaitForSingleObject for quick actions event '%s': %v",
status, quickActionsTriggerEventName, err)
select {
case <-time.After(5 * time.Second):
return false
case <-ctx.Done():
return true
}
}
// openQuickActions opens the quick actions window by spawning a new process.
func (s *serviceClient) openQuickActions() {
proc, err := os.Executable()
if err != nil {
log.Errorf("get executable path: %v", err)
return
}
cmd := exec.CommandContext(s.ctx, proc,
"--quick-actions=true",
"--daemon-addr="+s.addr,
)
if out := s.attachOutput(cmd); out != nil {
defer func() {
if err := out.Close(); err != nil {
log.Errorf("close log file %s: %v", s.logFile, err)
}
}()
}
log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr)
if err := cmd.Start(); err != nil {
log.Errorf("error starting quick actions window: %v", err)
return
}
go func() {
if err := cmd.Wait(); err != nil {
log.Debugf("quick actions window exited: %v", err)
}
}()
}
func sendShowWindowSignal(pid int32) error {
_, err := os.FindProcess(int(pid))
if err != nil {
return err
}
eventNamePtr, err := getEventNameUint16Pointer()
if err != nil {
return err
}
eventHandle, err := windows.OpenEvent(desiredAccesses, false, eventNamePtr)
if err != nil {
return err
}
err = windows.SetEvent(eventHandle)
if err != nil {
return fmt.Errorf("Error setting event: %w", err)
}
return nil
}

57
go.mod
View File

@@ -16,7 +16,7 @@ require (
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.0
golang.org/x/crypto v0.41.0 golang.org/x/crypto v0.41.0
golang.org/x/sys v0.35.0 golang.org/x/sys v0.35.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
@@ -28,8 +28,8 @@ require (
) )
require ( require (
fyne.io/fyne/v2 v2.7.0 fyne.io/fyne/v2 v2.5.3
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 fyne.io/systray v1.11.0
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/awnumar/memguard v0.23.0 github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2 v1.36.3
@@ -44,7 +44,7 @@ require (
github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/lib/v4 v4.2.0
github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/eko/gocache/store/go_cache/v4 v4.2.2
github.com/eko/gocache/store/redis/v4 v4.2.2 github.com/eko/gocache/store/redis/v4 v4.2.2
github.com/fsnotify/fsnotify v1.9.0 github.com/fsnotify/fsnotify v1.7.0
github.com/gliderlabs/ssh v0.3.8 github.com/gliderlabs/ssh v0.3.8
github.com/godbus/dbus/v5 v5.1.0 github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.0 github.com/golang-jwt/jwt/v5 v5.3.0
@@ -57,16 +57,14 @@ require (
github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0 github.com/hashicorp/go-version v1.6.0
github.com/jackc/pgx/v5 v5.5.5
github.com/libdns/route53 v1.5.0 github.com/libdns/route53 v1.5.0
github.com/libp2p/go-netroute v0.2.1 github.com/libp2p/go-netroute v0.2.1
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1 github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -86,7 +84,7 @@ require (
github.com/shirou/gopsutil/v3 v3.24.4 github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.10.0
github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go v0.31.0
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
@@ -102,17 +100,15 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.35.0 go.opentelemetry.io/otel/metric v1.35.0
go.opentelemetry.io/otel/sdk/metric v1.35.0 go.opentelemetry.io/otel/sdk/metric v1.35.0
go.uber.org/mock v0.5.0
go.uber.org/zap v1.27.0 go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3 goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/mod v0.26.0 golang.org/x/mod v0.26.0
golang.org/x/net v0.42.0 golang.org/x/net v0.42.0
golang.org/x/oauth2 v0.30.0 golang.org/x/oauth2 v0.28.0
golang.org/x/sync v0.16.0 golang.org/x/sync v0.16.0
golang.org/x/term v0.34.0 golang.org/x/term v0.34.0
golang.org/x/time v0.12.0
google.golang.org/api v0.177.0 google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7 gorm.io/driver/mysql v1.5.7
@@ -129,11 +125,10 @@ require (
dario.cat/mergo v1.0.0 // indirect dario.cat/mergo v1.0.0 // indirect
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/BurntSushi/toml v1.5.0 // indirect github.com/BurntSushi/toml v1.4.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/awnumar/memcall v0.4.0 // indirect github.com/awnumar/memcall v0.4.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
@@ -154,7 +149,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/containerd v1.7.29 // indirect github.com/containerd/containerd v1.7.27 // indirect
github.com/containerd/log v0.1.0 // indirect github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect
@@ -165,12 +160,11 @@ require (
github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v1.1.1 // indirect github.com/fredbi/uri v1.1.0 // indirect
github.com/fyne-io/gl-js v0.2.0 // indirect github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect
github.com/fyne-io/glfw-js v0.3.0 // indirect github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 // indirect
github.com/fyne-io/image v0.1.1 // indirect github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
github.com/fyne-io/oksvg v0.2.0 // indirect github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
@@ -178,7 +172,7 @@ require (
github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/go-text/render v0.2.0 // indirect github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.1 // indirect github.com/go-text/typesetting v0.2.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.1.2 // indirect github.com/google/btree v1.1.2 // indirect
@@ -186,19 +180,19 @@ require (
github.com/google/s2a-go v0.1.7 // indirect github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect github.com/googleapis/gax-go/v2 v2.12.3 // indirect
github.com/hack-pad/go-indexeddb v0.3.2 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/hack-pad/safejs v0.1.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect
@@ -218,8 +212,7 @@ require (
github.com/moby/term v0.5.0 // indirect github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect github.com/morikuni/aec v1.0.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect
github.com/nxadm/tail v1.4.8 // indirect github.com/nxadm/tail v1.4.8 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect
@@ -235,26 +228,28 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect
github.com/rymdport/portal v0.4.2 // indirect github.com/rymdport/portal v0.3.0 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/stretchr/objx v0.5.2 // indirect github.com/stretchr/objx v0.5.2 // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect github.com/vishvananda/netns v0.0.4 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.3 // indirect github.com/wlynxg/anet v0.0.3 // indirect
github.com/yuin/goldmark v1.7.8 // indirect github.com/yuin/goldmark v1.7.1 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect github.com/zeebo/blake3 v0.2.3 // indirect
go.opencensus.io v0.24.0 // indirect go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect
go.uber.org/mock v0.5.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.24.0 // indirect golang.org/x/image v0.18.0 // indirect
golang.org/x/text v0.28.0 // indirect golang.org/x/text v0.28.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.35.0 // indirect golang.org/x/tools v0.35.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect

551
go.sum

File diff suppressed because it is too large Load Diff

View File

@@ -1,31 +0,0 @@
package cache
import (
"sync"
"github.com/netbirdio/netbird/shared/management/proto"
)
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
NameServerGroups sync.Map
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}

View File

@@ -1,784 +0,0 @@
package controller
import (
"context"
"errors"
"fmt"
"os"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/util"
)
type Controller struct {
repo Repository
metrics *metrics
// This should not be here, but we need to maintain it for the time being
accountManagerMetrics *telemetry.AccountManagerMetrics
peersUpdateManager network_map.PeersUpdateManager
settingsManager settings.Manager
accountUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
requestBuffer account.RequestBuffer
proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
}
type bufferUpdate struct {
mu sync.Mutex
next *time.Timer
update atomic.Bool
}
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller {
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
if err != nil {
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
}
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
return &Controller{
repo: newRepository(store),
metrics: nMetrics,
accountManagerMetrics: metrics.AccountManagerMetrics(),
peersUpdateManager: peersUpdateManager,
requestBuffer: requestBuffer,
integratedPeerValidator: integratedPeerValidator,
settingsManager: settingsManager,
dnsDomain: dnsDomain,
proxyController: proxyController,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
}
}
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
}
globalStart := time.Now()
hasPeersConnected := false
for _, peer := range account.Peers {
if c.peersUpdateManager.HasChannel(peer.ID) {
hasPeersConnected = true
break
}
}
if !hasPeersConnected {
return nil
}
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validate peers: %v", err)
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return fmt.Errorf("failed to get proxy network maps: %v", err)
}
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get flow enabled status: %v", err)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
for _, peer := range account.Peers {
if !c.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
return
}
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
}(peer)
}
wg.Wait()
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
}
return nil
}
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return fmt.Errorf("recalculate network map cache: %v", err)
}
return c.sendUpdateAccountPeers(ctx, accountID)
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
if !c.peersUpdateManager.HasChannel(peerId) {
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err)
}
peer := account.GetPeer(peerId)
if peer == nil {
return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId)
}
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validated peers: %v", err)
}
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
postureChecks, err := c.getPeerPostureChecks(account, peerId)
if err != nil {
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return err
}
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
if err != nil {
return fmt.Errorf("failed to get extra settings: %v", err)
}
peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
return nil
}
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.UpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.UpdateAccountPeers(ctx, accountID)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountId)
if err != nil {
return err
}
peers, err := c.repo.GetAccountPeers(ctx, accountId)
if err != nil {
return err
}
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
c.peersUpdateManager.CloseChannel(ctx, peerId)
return nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, 0, nil
}
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
}
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, 0, err
}
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, networkMap, postureChecks, dnsFwdPort, nil
}
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
c.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (c *Controller) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
}
func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerAddedUpdNetworkMapCache(peerId)
}
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := c.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
c.updateAccountInHolder(account)
}
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
c.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (c *Controller) experimentalNetworkMap(accountId string) bool {
_, ok := c.expNewNetworkMapAIDs[accountId]
return c.expNewNetworkMap || ok
}
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
a := c.holder.GetAccount(account.Id)
if a == nil {
c.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
account.NetworkMapCache.UpdateAccountPointer(account)
c.holder.AddAccount(account)
}
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
return c.holder.GetAccount(accountID)
}
func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account {
a := c.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (c *Controller) updateAccountInHolder(account *types.Account) {
c.holder.AddAccount(account)
}
// GetDNSDomain returns the configured dnsDomain
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
return c.dnsDomain
}
if settings.DNSDomain == "" {
return c.dnsDomain
}
return settings.DNSDomain
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 {
return nil, nil
}
for _, policy := range account.Policies {
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
continue
}
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
return nil, err
}
}
return maps.Values(peerPostureChecks), nil
}
func (c *Controller) StartWarmup(ctx context.Context) {
var initialInterval int64
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
interval, err := strconv.Atoi(intervalStr)
if err != nil {
initialInterval = 1
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
} else {
initialInterval = int64(interval) * 10
go func() {
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
startupPeriod, err := strconv.Atoi(startupPeriodStr)
if err != nil {
startupPeriod = 1
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
}
time.Sleep(time.Duration(startupPeriod) * time.Second)
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
}()
}
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond))
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
}
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
if len(peers) == 0 {
return int64(network_map.OldForwarderPort)
}
reqVer := semver.Canonical(requiredVersion)
// Check if all peers have the required version or newer
for _, peer := range peers {
// Development version is always supported
if peer.Meta.WtVersion == "development" {
continue
}
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
if peerVersion == "" {
// If any peer doesn't have version info, return 0
return int64(network_map.OldForwarderPort)
}
// Compare versions
if semver.Compare(peerVersion, reqVer) < 0 {
return int64(network_map.OldForwarderPort)
}
}
// All peers have the required version or newer
return int64(network_map.DnsForwarderPort)
}
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
if err != nil {
return err
}
if !isInGroup {
return nil
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
if postureCheck == nil {
return errors.New("failed to add policy posture checks: posture checks not found")
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
return nil
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group := account.GetGroup(sourceGroup)
if group == nil {
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
}
if slices.Contains(group.Peers, peerID) {
return true, nil
}
}
}
return false, nil
}
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
c.UpdatePeerInNetworkMapCache(accountId, peer)
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
}
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
account, err := c.repo.GetAccountByPeerID(ctx, peerID)
if err != nil {
return nil, err
}
peer := account.GetPeer(peerID)
if peer == nil {
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
groups := make(map[string][]string)
for groupID, group := range account.Groups {
groups[groupID] = group.Peers
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, err
}
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
return networkMap, nil
}
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
}
func (c *Controller) IsConnected(peerID string) bool {
return c.peersUpdateManager.HasChannel(peerID)
}

Some files were not shown because too many files have changed in this diff Show More