mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-04 06:59:54 +00:00
Compare commits
91 Commits
test-ldfla
...
move-licen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
24b66fb406 | ||
|
|
9378b6b0a3 | ||
|
|
3779a3385f | ||
|
|
b5d75ad9c4 | ||
|
|
8db91abfdf | ||
|
|
6f817cad6d | ||
|
|
e3bb8c1b7b | ||
|
|
107066fa3d | ||
|
|
a7a85d4dc8 | ||
|
|
576b4a779c | ||
|
|
e6854dfd99 | ||
|
|
6f14134988 | ||
|
|
4fd64379da | ||
|
|
c20202a6c3 | ||
|
|
4386a21956 | ||
|
|
5882daf5d9 | ||
|
|
11d71e6e22 | ||
|
|
4dadcfd9bd | ||
|
|
34b55c600e | ||
|
|
316c0afa9a | ||
|
|
cf97799db8 | ||
|
|
4d297205c3 | ||
|
|
559f6aeeaf | ||
|
|
7216c201da | ||
|
|
4d89d0f115 | ||
|
|
610c880ec9 | ||
|
|
19adcb5f63 | ||
|
|
f3d31698da | ||
|
|
d9efe4e944 | ||
|
|
7e0bbaaa3c | ||
|
|
b3c7b3c7b2 | ||
|
|
66483ab48d | ||
|
|
5272fc2b18 | ||
|
|
4c53372815 | ||
|
|
79d28b71ee | ||
|
|
77a352763d | ||
|
|
cdd5c6c005 | ||
|
|
b1a9242c98 | ||
|
|
b43ef4f17b | ||
|
|
758a97c352 | ||
|
|
d93b7c2f38 | ||
|
|
fa893aa0a4 | ||
|
|
ac7120871b | ||
|
|
9a7daa132e | ||
|
|
cdded8c22e | ||
|
|
e4e0b8fff9 | ||
|
|
a4b067553d | ||
|
|
088956645f | ||
|
|
aa30b7afe8 | ||
|
|
f1bb4d2ac3 | ||
|
|
982841e25b | ||
|
|
a476b8d12f | ||
|
|
a21f924b26 | ||
|
|
9e51d2e8fb | ||
|
|
3e490d974c | ||
|
|
04bb314426 | ||
|
|
6e15882c11 | ||
|
|
76f9e11b29 | ||
|
|
612de2c784 | ||
|
|
1fdde66c31 | ||
|
|
5970591d24 | ||
|
|
0d5408baec | ||
|
|
96084e3a02 | ||
|
|
4bbca28eb6 | ||
|
|
279b77dee0 | ||
|
|
9d1554f9f7 | ||
|
|
f56075ca15 | ||
|
|
6ed846ae29 | ||
|
|
520f2cfdb4 | ||
|
|
0f79a8942d | ||
|
|
5299e9fda3 | ||
|
|
11bdf5b3a5 | ||
|
|
5fc95d4a0c | ||
|
|
c7884039b8 | ||
|
|
26fc32f1be | ||
|
|
a79cb1c11b | ||
|
|
306d75fe1a | ||
|
|
9468e69c8c | ||
|
|
f51ce7cee5 | ||
|
|
d47c6b624e | ||
|
|
471f90e8db | ||
|
|
1a3b04d2fe | ||
|
|
51b9e93eb9 | ||
|
|
2952669e97 | ||
|
|
7cd44a9a3c | ||
|
|
8684981b57 | ||
|
|
8e94d85d14 | ||
|
|
631b77dc3c | ||
|
|
50ac3d437e | ||
|
|
49bbd90557 | ||
|
|
bb74e903cd |
116
.github/workflows/check-license-dependencies.yml
vendored
116
.github/workflows/check-license-dependencies.yml
vendored
@@ -3,108 +3,40 @@ name: Check License Dependencies
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
paths:
|
|
||||||
- 'go.mod'
|
|
||||||
- 'go.sum'
|
|
||||||
- '.github/workflows/check-license-dependencies.yml'
|
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
|
||||||
- 'go.mod'
|
|
||||||
- 'go.sum'
|
|
||||||
- '.github/workflows/check-license-dependencies.yml'
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check-internal-dependencies:
|
check-dependencies:
|
||||||
name: Check Internal AGPL Dependencies
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Check for problematic license dependencies
|
|
||||||
run: |
|
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
|
||||||
FOUND_ISSUES=0
|
|
||||||
while IFS= read -r dir; do
|
|
||||||
echo "=== Checking $dir ==="
|
|
||||||
# Search for problematic imports, excluding test files
|
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
|
||||||
if [ -n "$RESULTS" ]; then
|
|
||||||
echo "❌ Found problematic dependencies:"
|
|
||||||
echo "$RESULTS"
|
|
||||||
FOUND_ISSUES=1
|
|
||||||
else
|
|
||||||
echo "✓ No problematic dependencies found"
|
|
||||||
fi
|
|
||||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
|
||||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
|
||||||
exit 1
|
|
||||||
else
|
|
||||||
echo ""
|
|
||||||
echo "✅ All internal license dependencies are clean"
|
|
||||||
fi
|
|
||||||
|
|
||||||
check-external-licenses:
|
|
||||||
name: Check External GPL/AGPL Licenses
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Check for problematic license dependencies
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version-file: 'go.mod'
|
|
||||||
cache: true
|
|
||||||
|
|
||||||
- name: Install go-licenses
|
|
||||||
run: go install github.com/google/go-licenses@v1.6.0
|
|
||||||
|
|
||||||
- name: Check for GPL/AGPL licensed dependencies
|
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
# Find all directories except the problematic ones and system dirs
|
||||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
FOUND_ISSUES=0
|
||||||
|
while IFS= read -r dir; do
|
||||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
echo "=== Checking $dir ==="
|
||||||
echo "Found copyleft licensed dependencies:"
|
# Search for problematic imports, excluding test files
|
||||||
echo "$COPYLEFT_DEPS"
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
echo ""
|
if [ -n "$RESULTS" ]; then
|
||||||
|
echo "❌ Found problematic dependencies:"
|
||||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
echo "$RESULTS"
|
||||||
INCOMPATIBLE=""
|
FOUND_ISSUES=1
|
||||||
while IFS=',' read -r package url license; do
|
else
|
||||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
echo "✓ No problematic dependencies found"
|
||||||
# Find ALL packages that import this GPL package using go list
|
|
||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
|
||||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
|
||||||
else
|
|
||||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
done <<< "$COPYLEFT_DEPS"
|
|
||||||
|
|
||||||
if [ -n "$INCOMPATIBLE" ]; then
|
|
||||||
echo ""
|
|
||||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
|
||||||
echo -e "$INCOMPATIBLE"
|
|
||||||
exit 1
|
|
||||||
fi
|
fi
|
||||||
fi
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||||
|
|
||||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
echo ""
|
||||||
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
|
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||||
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "✅ All license dependencies are clean"
|
||||||
|
fi
|
||||||
|
|||||||
2
.github/workflows/wasm-build-validation.yml
vendored
2
.github/workflows/wasm-build-validation.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version: "1.23.x"
|
||||||
- name: Build Wasm client
|
- name: Build Wasm client
|
||||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm -ldflags="-s -w" ./client/wasm/cmd
|
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||||
env:
|
env:
|
||||||
CGO_ENABLED: 0
|
CGO_ENABLED: 0
|
||||||
- name: Check Wasm build size
|
- name: Check Wasm build size
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/upload-server/types"
|
"github.com/netbirdio/netbird/upload-server/types"
|
||||||
)
|
)
|
||||||
@@ -97,6 +98,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
|||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
|
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||||
SystemInfo: systemInfoFlag,
|
SystemInfo: systemInfoFlag,
|
||||||
LogFileCount: logFileCount,
|
LogFileCount: logFileCount,
|
||||||
}
|
}
|
||||||
@@ -218,6 +220,9 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||||
|
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||||
|
|
||||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||||
return waitErr
|
return waitErr
|
||||||
}
|
}
|
||||||
@@ -225,8 +230,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
cmd.Println("Creating debug bundle...")
|
cmd.Println("Creating debug bundle...")
|
||||||
|
|
||||||
|
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||||
|
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
|
Status: statusOutput,
|
||||||
SystemInfo: systemInfoFlag,
|
SystemInfo: systemInfoFlag,
|
||||||
LogFileCount: logFileCount,
|
LogFileCount: logFileCount,
|
||||||
}
|
}
|
||||||
@@ -293,6 +301,25 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||||
|
var statusOutputString string
|
||||||
|
statusResp, err := getStatus(cmd.Context(), true)
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||||
|
} else {
|
||||||
|
pm := profilemanager.NewProfileManager()
|
||||||
|
var profName string
|
||||||
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
profName = activeProf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||||
|
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return statusOutputString
|
||||||
|
}
|
||||||
|
|
||||||
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
||||||
ticker := time.NewTicker(1 * time.Second)
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
@@ -351,7 +378,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
|
|||||||
InternalConfig: config,
|
InternalConfig: config,
|
||||||
StatusRecorder: recorder,
|
StatusRecorder: recorder,
|
||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogPath: logFilePath,
|
LogFile: logFilePath,
|
||||||
},
|
},
|
||||||
debug.BundleConfig{
|
debug.BundleConfig{
|
||||||
IncludeSystemInfo: true,
|
IncludeSystemInfo: true,
|
||||||
|
|||||||
@@ -106,13 +106,6 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
|||||||
Username: &username,
|
Username: &username,
|
||||||
}
|
}
|
||||||
|
|
||||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
|
||||||
} else if profileState.Email != "" {
|
|
||||||
loginRequest.Hint = &profileState.Email
|
|
||||||
}
|
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||||
}
|
}
|
||||||
@@ -248,7 +241,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
|||||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("foreground login failed: %v", err)
|
return fmt.Errorf("foreground login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -276,7 +269,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
|
||||||
needsLogin := false
|
needsLogin := false
|
||||||
|
|
||||||
err := WithBackOff(func() error {
|
err := WithBackOff(func() error {
|
||||||
@@ -293,7 +286,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if setupKey == "" && needsLogin {
|
if setupKey == "" && needsLogin {
|
||||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
|
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -322,17 +315,8 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
|
||||||
hint := ""
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||||
pm := profilemanager.NewProfileManager()
|
|
||||||
profileState, err := pm.GetProfileState(profileName)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
|
||||||
} else if profileState.Email != "" {
|
|
||||||
hint = profileState.Email
|
|
||||||
}
|
|
||||||
|
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -259,7 +259,6 @@ func isServiceRunning() (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
networkdConf = "/etc/systemd/networkd.conf"
|
|
||||||
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
||||||
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
||||||
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
||||||
@@ -274,16 +273,12 @@ ManageForeignRoutingPolicyRules=no
|
|||||||
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
||||||
// systemd-networkd from removing NetBird's routes and policy rules.
|
// systemd-networkd from removing NetBird's routes and policy rules.
|
||||||
func configureSystemdNetworkd() error {
|
func configureSystemdNetworkd() error {
|
||||||
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
|
parentDir := filepath.Dir(networkdConfDir)
|
||||||
log.Debug("systemd-networkd not in use, skipping configuration")
|
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
|
||||||
|
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:gosec // standard networkd permissions
|
|
||||||
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("create networkd.conf.d directory: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// nolint:gosec // standard networkd permissions
|
// nolint:gosec // standard networkd permissions
|
||||||
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
||||||
return fmt.Errorf("write networkd configuration: %w", err)
|
return fmt.Errorf("write networkd configuration: %w", err)
|
||||||
|
|||||||
@@ -14,9 +14,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||||
@@ -36,7 +34,6 @@ const (
|
|||||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||||
disableSSHAuthFlag = "disable-ssh-auth"
|
disableSSHAuthFlag = "disable-ssh-auth"
|
||||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -50,7 +47,6 @@ var (
|
|||||||
knownHostsFile string
|
knownHostsFile string
|
||||||
identityFile string
|
identityFile string
|
||||||
skipCachedToken bool
|
skipCachedToken bool
|
||||||
requestPTY bool
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -60,7 +56,6 @@ var (
|
|||||||
enableSSHLocalPortForward bool
|
enableSSHLocalPortForward bool
|
||||||
enableSSHRemotePortForward bool
|
enableSSHRemotePortForward bool
|
||||||
disableSSHAuth bool
|
disableSSHAuth bool
|
||||||
sshJWTCacheTTL int
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -70,16 +65,13 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||||
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
|
||||||
|
|
||||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||||
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||||
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
|
|
||||||
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||||
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
|
||||||
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
|
||||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
|
||||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||||
@@ -105,9 +97,9 @@ SSH Options:
|
|||||||
-p, --port int Remote SSH port (default 22)
|
-p, --port int Remote SSH port (default 22)
|
||||||
-u, --user string SSH username
|
-u, --user string SSH username
|
||||||
--login string SSH username (alias for --user)
|
--login string SSH username (alias for --user)
|
||||||
-t, --tty Force pseudo-terminal allocation
|
|
||||||
--strict-host-key-checking Enable strict host key checking (default: true)
|
--strict-host-key-checking Enable strict host key checking (default: true)
|
||||||
-o, --known-hosts string Path to known_hosts file
|
-o, --known-hosts string Path to known_hosts file
|
||||||
|
-i, --identity string Path to SSH private key file
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
netbird ssh peer-hostname
|
netbird ssh peer-hostname
|
||||||
@@ -115,10 +107,8 @@ Examples:
|
|||||||
netbird ssh --login root peer-hostname
|
netbird ssh --login root peer-hostname
|
||||||
netbird ssh peer-hostname ls -la
|
netbird ssh peer-hostname ls -la
|
||||||
netbird ssh peer-hostname whoami
|
netbird ssh peer-hostname whoami
|
||||||
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
|
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||||
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
|
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||||
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
|
||||||
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
|
||||||
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
||||||
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
||||||
DisableFlagParsing: true,
|
DisableFlagParsing: true,
|
||||||
@@ -153,10 +143,10 @@ func sshFn(cmd *cobra.Command, args []string) error {
|
|||||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||||
sshctx, cancel := context.WithCancel(ctx)
|
sshctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
errCh := make(chan error, 1)
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := runSSH(sshctx, host, cmd); err != nil {
|
if err := runSSH(sshctx, host, cmd); err != nil {
|
||||||
errCh <- err
|
cmd.Printf("Error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -164,10 +154,6 @@ func sshFn(cmd *cobra.Command, args []string) error {
|
|||||||
select {
|
select {
|
||||||
case <-sig:
|
case <-sig:
|
||||||
cancel()
|
cancel()
|
||||||
<-sshctx.Done()
|
|
||||||
return nil
|
|
||||||
case err := <-errCh:
|
|
||||||
return err
|
|
||||||
case <-sshctx.Done():
|
case <-sshctx.Done():
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -365,7 +351,6 @@ type sshFlags struct {
|
|||||||
Port int
|
Port int
|
||||||
Username string
|
Username string
|
||||||
Login string
|
Login string
|
||||||
RequestPTY bool
|
|
||||||
StrictHostKeyChecking bool
|
StrictHostKeyChecking bool
|
||||||
KnownHostsFile string
|
KnownHostsFile string
|
||||||
IdentityFile string
|
IdentityFile string
|
||||||
@@ -388,24 +373,22 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
|||||||
flags := &sshFlags{}
|
flags := &sshFlags{}
|
||||||
|
|
||||||
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||||
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
|
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
|
||||||
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||||
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
|
fs.String("user", "", sshUsernameDesc)
|
||||||
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||||
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
|
|
||||||
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
|
|
||||||
|
|
||||||
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||||
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||||
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
|
fs.String("known-hosts", "", "Path to known_hosts file")
|
||||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||||
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
fs.String("identity", "", "Path to SSH private key file")
|
||||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
|
||||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||||
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
fs.String("config", defaultConfigPath, "Netbird config file location")
|
||||||
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||||
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
|
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
|
||||||
|
|
||||||
return fs, flags
|
return fs, flags
|
||||||
}
|
}
|
||||||
@@ -426,10 +409,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
|||||||
fs, flags := createSSHFlagSet()
|
fs, flags := createSSHFlagSet()
|
||||||
|
|
||||||
if err := fs.Parse(filteredArgs); err != nil {
|
if err := fs.Parse(filteredArgs); err != nil {
|
||||||
if errors.Is(err, flag.ErrHelp) {
|
return parseHostnameAndCommand(filteredArgs)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
remaining := fs.Args()
|
remaining := fs.Args()
|
||||||
@@ -444,7 +424,6 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
|||||||
username = flags.Login
|
username = flags.Login
|
||||||
}
|
}
|
||||||
|
|
||||||
requestPTY = flags.RequestPTY
|
|
||||||
strictHostKeyChecking = flags.StrictHostKeyChecking
|
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||||
knownHostsFile = flags.KnownHostsFile
|
knownHostsFile = flags.KnownHostsFile
|
||||||
identityFile = flags.IdentityFile
|
identityFile = flags.IdentityFile
|
||||||
@@ -541,29 +520,10 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
// executeSSHCommand executes a command over SSH.
|
// executeSSHCommand executes a command over SSH.
|
||||||
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
||||||
var err error
|
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
|
||||||
if requestPTY {
|
|
||||||
err = c.ExecuteCommandWithPTY(ctx, command)
|
|
||||||
} else {
|
|
||||||
err = c.ExecuteCommandWithIO(ctx, command)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var exitErr *ssh.ExitError
|
|
||||||
if errors.As(err, &exitErr) {
|
|
||||||
os.Exit(exitErr.ExitStatus())
|
|
||||||
}
|
|
||||||
|
|
||||||
var exitMissingErr *ssh.ExitMissingError
|
|
||||||
if errors.As(err, &exitMissingErr) {
|
|
||||||
log.Debugf("Remote command exited without exit status: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("execute command: %w", err)
|
return fmt.Errorf("execute command: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -575,13 +535,6 @@ func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
|
|||||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var exitMissingErr *ssh.ExitMissingError
|
|
||||||
if errors.As(err, &exitMissingErr) {
|
|
||||||
log.Debugf("Remote terminal exited without exit status: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("open terminal: %w", err)
|
return fmt.Errorf("open terminal: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -765,11 +718,6 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create SSH proxy: %w", err)
|
return fmt.Errorf("create SSH proxy: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if err := proxy.Close(); err != nil {
|
|
||||||
log.Debugf("close SSH proxy: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := proxy.Connect(cmd.Context()); err != nil {
|
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||||
return fmt.Errorf("SSH proxy: %w", err)
|
return fmt.Errorf("SSH proxy: %w", err)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pkg/sftp"
|
"github.com/pkg/sftp"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -52,7 +51,7 @@ func sftpMainDirect(cmd *cobra.Command) error {
|
|||||||
if windowsDomain != "" {
|
if windowsDomain != "" {
|
||||||
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
||||||
}
|
}
|
||||||
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
|
if currentUser.Username != expectedUsername && currentUser.Username != windowsUsername {
|
||||||
cmd.PrintErrf("user switching failed\n")
|
cmd.PrintErrf("user switching failed\n")
|
||||||
os.Exit(sshserver.ExitCodeValidationFail)
|
os.Exit(sshserver.ExitCodeValidationFail)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -667,51 +667,3 @@ func TestSSHCommand_ParameterIsolation(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
|
|
||||||
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args []string
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "invalid long flag before hostname",
|
|
||||||
args: []string{"--invalid-flag", "hostname"},
|
|
||||||
description: "Invalid flag should return parse error, not treat flag as hostname",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid short flag before hostname",
|
|
||||||
args: []string{"-x", "hostname"},
|
|
||||||
description: "Invalid short flag should return parse error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid flag with value before hostname",
|
|
||||||
args: []string{"--invalid-option=value", "hostname"},
|
|
||||||
description: "Invalid flag with value should return parse error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "typo in known flag",
|
|
||||||
args: []string{"--por", "2222", "hostname"},
|
|
||||||
description: "Typo in flag name should return parse error (not silently ignored)",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// Reset global variables
|
|
||||||
host = ""
|
|
||||||
username = ""
|
|
||||||
port = 22
|
|
||||||
command = ""
|
|
||||||
|
|
||||||
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
|
||||||
|
|
||||||
// Should return an error for invalid flags
|
|
||||||
assert.Error(t, err, tt.description)
|
|
||||||
|
|
||||||
// Should not have set host to the invalid flag
|
|
||||||
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
profName = activeProf.Name
|
profName = activeProf.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
case yamlFlag:
|
case yamlFlag:
|
||||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||||
default:
|
default:
|
||||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -13,11 +13,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
client "github.com/netbirdio/netbird/client/server"
|
client "github.com/netbirdio/netbird/client/server"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
@@ -89,7 +84,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
}
|
}
|
||||||
t.Cleanup(cleanUp)
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
jobManager := job.NewJobManager(nil, store)
|
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -115,18 +110,13 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
Return(&types.Settings{}, nil).
|
Return(&types.Settings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
ctx := context.Background()
|
accountManager, err := mgmt.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
|
||||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
|
||||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
|||||||
|
|
||||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||||
|
|
||||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
|
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("foreground login failed: %v", err)
|
return fmt.Errorf("foreground login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
|||||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||||
|
|
||||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
return connectClient.Run(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||||
@@ -286,13 +286,6 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
|||||||
loginRequest.ProfileName = &activeProf.Name
|
loginRequest.ProfileName = &activeProf.Name
|
||||||
loginRequest.Username = &username
|
loginRequest.Username = &username
|
||||||
|
|
||||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
|
||||||
} else if profileState.Email != "" {
|
|
||||||
loginRequest.Hint = &profileState.Email
|
|
||||||
}
|
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
var loginResp *proto.LoginResponse
|
var loginResp *proto.LoginResponse
|
||||||
|
|
||||||
@@ -362,18 +355,14 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
req.EnableSSHSFTP = &enableSSHSFTP
|
req.EnableSSHSFTP = &enableSSHSFTP
|
||||||
}
|
}
|
||||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||||
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
req.EnableSSHLocalPortForward = &enableSSHLocalPortForward
|
||||||
}
|
}
|
||||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||||
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
|
||||||
}
|
}
|
||||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
req.DisableSSHAuth = &disableSSHAuth
|
req.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
|
||||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
|
||||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
|
||||||
}
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
log.Errorf("parse interface name: %v", err)
|
log.Errorf("parse interface name: %v", err)
|
||||||
@@ -478,10 +467,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
ic.DisableSSHAuth = &disableSSHAuth
|
ic.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
|
||||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -602,11 +587,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
|
||||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
|
||||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -173,7 +173,6 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||||
|
|
||||||
client := internal.NewConnectClient(ctx, c.config, recorder)
|
client := internal.NewConnectClient(ctx, c.config, recorder)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
@@ -181,7 +180,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
run := make(chan struct{})
|
run := make(chan struct{})
|
||||||
clientErr := make(chan error, 1)
|
clientErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := client.Run(run, ""); err != nil {
|
if err := client.Run(run); err != nil {
|
||||||
clientErr <- err
|
clientErr <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
ipset "github.com/lrh3321/ipset-go"
|
"github.com/nadoo/ipset"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
@@ -41,13 +40,19 @@ type aclManager struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||||
return &aclManager{
|
m := &aclManager{
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
entries: make(map[string][][]string),
|
entries: make(map[string][][]string),
|
||||||
optionalEntries: make(map[string][]entry),
|
optionalEntries: make(map[string][]entry),
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
if err := ipset.Init(); err != nil {
|
||||||
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||||
@@ -93,8 +98,8 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
specs = append(specs, "-j", actionToStr(action))
|
specs = append(specs, "-j", actionToStr(action))
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||||
}
|
}
|
||||||
// if ruleset already exists it means we already have the firewall rule
|
// if ruleset already exists it means we already have the firewall rule
|
||||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||||
@@ -108,18 +113,14 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
}}, nil
|
}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.flushIPSet(ipsetName); err != nil {
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
||||||
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
|
||||||
} else {
|
|
||||||
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err := m.createIPSet(ipsetName); err != nil {
|
if err := ipset.Create(ipsetName); err != nil {
|
||||||
return nil, fmt.Errorf("create ipset: %w", err)
|
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||||
}
|
}
|
||||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ipList := newIpList(ip.String())
|
ipList := newIpList(ip.String())
|
||||||
@@ -171,16 +172,11 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("invalid rule type")
|
return fmt.Errorf("invalid rule type")
|
||||||
}
|
}
|
||||||
|
|
||||||
shouldDestroyIpset := false
|
|
||||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||||
// delete IP from ruleset IPs list and ipset
|
// delete IP from ruleset IPs list and ipset
|
||||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||||
ip := net.ParseIP(r.ip)
|
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||||
if ip == nil {
|
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||||
return fmt.Errorf("parse IP %s", r.ip)
|
|
||||||
}
|
|
||||||
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
|
||||||
return fmt.Errorf("delete ip from ipset: %w", err)
|
|
||||||
}
|
}
|
||||||
delete(ipsetList.ips, r.ip)
|
delete(ipsetList.ips, r.ip)
|
||||||
}
|
}
|
||||||
@@ -194,7 +190,10 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
// we delete last IP from the set, that means we need to delete
|
// we delete last IP from the set, that means we need to delete
|
||||||
// set itself and associated firewall rule too
|
// set itself and associated firewall rule too
|
||||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||||
shouldDestroyIpset = true
|
|
||||||
|
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||||
|
log.Errorf("delete empty ipset: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||||
@@ -207,16 +206,6 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldDestroyIpset {
|
|
||||||
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
|
||||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
|
||||||
log.Debugf("destroy empty ipset: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Errorf("destroy empty ipset: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateState()
|
m.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -275,19 +264,11 @@ func (m *aclManager) cleanChains() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||||
if err := m.flushIPSet(ipsetName); err != nil {
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
|
||||||
} else {
|
|
||||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err := m.destroyIPSet(ipsetName); err != nil {
|
if err := ipset.Destroy(ipsetName); err != nil {
|
||||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||||
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
|
||||||
} else {
|
|
||||||
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
m.ipsetStore.deleteIpset(ipsetName)
|
m.ipsetStore.deleteIpset(ipsetName)
|
||||||
}
|
}
|
||||||
@@ -387,8 +368,8 @@ func (m *aclManager) updateState() {
|
|||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||||
matchByIP := true
|
matchByIP := true
|
||||||
// don't use IP matching if IP is 0.0.0.0
|
// don't use IP matching if IP is ip 0.0.0.0
|
||||||
if ip.IsUnspecified() {
|
if ip.String() == "0.0.0.0" {
|
||||||
matchByIP = false
|
matchByIP = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -435,61 +416,3 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
|||||||
return ipsetName + actionSuffix
|
return ipsetName + actionSuffix
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) createIPSet(name string) error {
|
|
||||||
opts := ipset.CreateOptions{
|
|
||||||
Replace: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
|
||||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("created ipset %s with type hash:net", name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
|
||||||
cidr := uint8(32)
|
|
||||||
if ip.To4() == nil {
|
|
||||||
cidr = 128
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := &ipset.Entry{
|
|
||||||
IP: ip,
|
|
||||||
CIDR: cidr,
|
|
||||||
Replace: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Add(name, entry); err != nil {
|
|
||||||
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
|
||||||
cidr := uint8(32)
|
|
||||||
if ip.To4() == nil {
|
|
||||||
cidr = 128
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := &ipset.Entry{
|
|
||||||
IP: ip,
|
|
||||||
CIDR: cidr,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Del(name, entry); err != nil {
|
|
||||||
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) flushIPSet(name string) error {
|
|
||||||
return ipset.Flush(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) destroyIPSet(name string) error {
|
|
||||||
return ipset.Destroy(name)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
ipset "github.com/lrh3321/ipset-go"
|
"github.com/nadoo/ipset"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
@@ -107,6 +107,10 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if err := ipset.Init(); err != nil {
|
||||||
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -228,12 +232,12 @@ func (r *router) findSets(rule []string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||||
if err := r.createIPSet(setName); err != nil {
|
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||||
return fmt.Errorf("create set %s: %w", setName, err)
|
return fmt.Errorf("create set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, prefix := range sources {
|
for _, prefix := range sources {
|
||||||
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -242,7 +246,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteIpSet(setName string) error {
|
func (r *router) deleteIpSet(setName string) error {
|
||||||
if err := r.destroyIPSet(setName); err != nil {
|
if err := ipset.Destroy(setName); err != nil {
|
||||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -911,8 +915,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
@@ -989,37 +993,3 @@ func applyPort(flag string, port *firewall.Port) []string {
|
|||||||
|
|
||||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIPSet(name string) error {
|
|
||||||
opts := ipset.CreateOptions{
|
|
||||||
Replace: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
|
||||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("created ipset %s with type hash:net", name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
|
||||||
addr := prefix.Addr()
|
|
||||||
ip := addr.AsSlice()
|
|
||||||
|
|
||||||
entry := &ipset.Entry{
|
|
||||||
IP: ip,
|
|
||||||
CIDR: uint8(prefix.Bits()),
|
|
||||||
Replace: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Add(name, entry); err != nil {
|
|
||||||
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) destroyIPSet(name string) error {
|
|
||||||
return ipset.Destroy(name)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -18,6 +20,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
||||||
|
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
||||||
|
|
||||||
// Backoff returns a backoff configuration for gRPC calls
|
// Backoff returns a backoff configuration for gRPC calls
|
||||||
func Backoff(ctx context.Context) backoff.BackOff {
|
func Backoff(ctx context.Context) backoff.BackOff {
|
||||||
b := backoff.NewExponentialBackOff()
|
b := backoff.NewExponentialBackOff()
|
||||||
@@ -26,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
||||||
|
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
||||||
|
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||||
|
conn.Connect()
|
||||||
|
|
||||||
|
state := conn.GetState()
|
||||||
|
for state != connectivity.Ready && state != connectivity.Shutdown {
|
||||||
|
if !conn.WaitForStateChange(ctx, state) {
|
||||||
|
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
||||||
|
}
|
||||||
|
state = conn.GetState()
|
||||||
|
}
|
||||||
|
|
||||||
|
if state == connectivity.Shutdown {
|
||||||
|
return ErrConnectionShutdown
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||||
@@ -43,22 +68,25 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
conn, err := grpc.NewClient(
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
|
||||||
connCtx,
|
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(tlsEnabled, component),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
grpc.WithBlock(),
|
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("dial context: %w", err)
|
return nil, fmt.Errorf("new client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := waitForConnectionReady(ctx, conn); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -10,13 +9,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/pion/transport/v3/stdnet"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// keep darwin compatibility
|
// keep darwin compatibility
|
||||||
@@ -41,7 +40,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
addr := "100.64.0.1/8"
|
addr := "100.64.0.1/8"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -124,7 +123,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
|||||||
func Test_CreateInterface(t *testing.T) {
|
func Test_CreateInterface(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
||||||
wgIP := "10.99.99.1/32"
|
wgIP := "10.99.99.1/32"
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -167,7 +166,7 @@ func Test_Close(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||||
wgIP := "10.99.99.2/32"
|
wgIP := "10.99.99.2/32"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -212,7 +211,7 @@ func TestRecreation(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||||
wgIP := "10.99.99.2/32"
|
wgIP := "10.99.99.2/32"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -285,7 +284,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
||||||
wgIP := "10.99.99.5/30"
|
wgIP := "10.99.99.5/30"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -340,7 +339,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
func Test_UpdatePeer(t *testing.T) {
|
func Test_UpdatePeer(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
wgIP := "10.99.99.9/30"
|
wgIP := "10.99.99.9/30"
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -410,7 +409,7 @@ func Test_UpdatePeer(t *testing.T) {
|
|||||||
func Test_RemovePeer(t *testing.T) {
|
func Test_RemovePeer(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
wgIP := "10.99.99.13/30"
|
wgIP := "10.99.99.13/30"
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -472,7 +471,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
peer2wgPort := 33200
|
peer2wgPort := 33200
|
||||||
|
|
||||||
keepAlive := 1 * time.Second
|
keepAlive := 1 * time.Second
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -515,7 +514,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
||||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||||
|
|
||||||
newNet, err = stdnet.NewNet(context.Background(), nil)
|
newNet, err = stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package udpmux
|
package udpmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -13,9 +12,8 @@ import (
|
|||||||
"github.com/pion/logging"
|
"github.com/pion/logging"
|
||||||
"github.com/pion/stun/v3"
|
"github.com/pion/stun/v3"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
|
"github.com/pion/transport/v3/stdnet"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -201,7 +199,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
|||||||
if len(networks) > 0 {
|
if len(networks) > 0 {
|
||||||
if m.params.Net == nil {
|
if m.params.Net == nil {
|
||||||
var err error
|
var err error
|
||||||
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
|
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
||||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -128,34 +128,9 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
|
|||||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.providerConfig.LoginHint != "" {
|
|
||||||
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
|
|
||||||
if deviceCode.VerificationURI != "" {
|
|
||||||
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return deviceCode, err
|
return deviceCode, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendLoginHint(uri, loginHint string) string {
|
|
||||||
if uri == "" || loginHint == "" {
|
|
||||||
return uri
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedURL, err := url.Parse(uri)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to parse verification URI for login_hint: %v", err)
|
|
||||||
return uri
|
|
||||||
}
|
|
||||||
|
|
||||||
query := parsedURL.Query()
|
|
||||||
query.Set("login_hint", loginHint)
|
|
||||||
parsedURL.RawQuery = query.Encode()
|
|
||||||
|
|
||||||
return parsedURL.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
form.Add("client_id", d.providerConfig.ClientID)
|
form.Add("client_id", d.providerConfig.ClientID)
|
||||||
|
|||||||
@@ -66,34 +66,32 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
//
|
//
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
|
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// fallback to device code flow
|
||||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
||||||
log.Debug("falling back to device code flow")
|
log.Debug("falling back to device code flow")
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
}
|
}
|
||||||
return pkceFlow, nil
|
return pkceFlow, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
|
||||||
|
|
||||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch s, ok := gstatus.FromError(err); {
|
switch s, ok := gstatus.FromError(err); {
|
||||||
@@ -109,7 +107,5 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
|
||||||
|
|
||||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,9 +109,6 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if p.providerConfig.LoginHint != "" {
|
|
||||||
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
|
|
||||||
}
|
|
||||||
|
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||||
|
|
||||||
@@ -192,20 +189,17 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token,
|
|||||||
|
|
||||||
if authError := query.Get(queryError); authError != "" {
|
if authError := query.Get(queryError); authError != "" {
|
||||||
authErrorDesc := query.Get(queryErrorDesc)
|
authErrorDesc := query.Get(queryErrorDesc)
|
||||||
if authErrorDesc != "" {
|
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||||
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("authentication failed: %s", authError)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prevent timing attacks on the state
|
// Prevent timing attacks on the state
|
||||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||||
return nil, fmt.Errorf("authentication failed: Invalid state")
|
return nil, fmt.Errorf("invalid state")
|
||||||
}
|
}
|
||||||
|
|
||||||
code := query.Get(queryCode)
|
code := query.Get(queryCode)
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, fmt.Errorf("authentication failed: missing code")
|
return nil, fmt.Errorf("missing code")
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.oAuthConfig.Exchange(
|
return p.oAuthConfig.Exchange(
|
||||||
@@ -234,7 +228,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
||||||
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err)
|
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ func NewConnectClient(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
config *profilemanager.Config,
|
config *profilemanager.Config,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
|
|
||||||
) *ConnectClient {
|
) *ConnectClient {
|
||||||
return &ConnectClient{
|
return &ConnectClient{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -62,8 +63,8 @@ func NewConnectClient(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run with main logic.
|
// Run with main logic.
|
||||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
func (c *ConnectClient) Run(runningChan chan struct{}) error {
|
||||||
return c.run(MobileDependency{}, runningChan, logPath)
|
return c.run(MobileDependency{}, runningChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunOnAndroid with main logic on mobile system
|
// RunOnAndroid with main logic on mobile system
|
||||||
@@ -82,7 +83,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, "")
|
return c.run(mobileDependency, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) RunOniOS(
|
func (c *ConnectClient) RunOniOS(
|
||||||
@@ -100,10 +101,10 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
DnsManager: dnsManager,
|
DnsManager: dnsManager,
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, "")
|
return c.run(mobileDependency, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
rec := c.statusRecorder
|
rec := c.statusRecorder
|
||||||
@@ -246,7 +247,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
relayURLs, token := parseRelayInfo(loginResp)
|
relayURLs, token := parseRelayInfo(loginResp)
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@@ -270,7 +271,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, c.config)
|
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
@@ -409,7 +410,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||||
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
|
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||||
nm := false
|
nm := false
|
||||||
if config.NetworkMonitor != nil {
|
if config.NetworkMonitor != nil {
|
||||||
nm = *config.NetworkMonitor
|
nm = *config.NetworkMonitor
|
||||||
@@ -444,10 +445,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
|||||||
|
|
||||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||||
|
|
||||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||||
LogPath: logPath,
|
|
||||||
|
|
||||||
ProfileConfig: config,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
|
|||||||
@@ -27,10 +27,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const readmeContent = `Netbird debug bundle
|
const readmeContent = `Netbird debug bundle
|
||||||
@@ -46,8 +44,6 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
|
|||||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
|
||||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
|
||||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||||
config.txt: Anonymized configuration information of the NetBird client.
|
config.txt: Anonymized configuration information of the NetBird client.
|
||||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||||
@@ -188,20 +184,6 @@ The ip_rules.txt file contains detailed IP routing rule information:
|
|||||||
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
|
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
|
||||||
|
|
||||||
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
|
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
|
||||||
|
|
||||||
DNS Configuration
|
|
||||||
The debug bundle includes platform-specific DNS configuration files:
|
|
||||||
|
|
||||||
resolv.conf (Unix systems):
|
|
||||||
- Contains DNS resolver configuration from /etc/resolv.conf
|
|
||||||
- Includes nameserver entries, search domains, and resolver options
|
|
||||||
- All IP addresses and domain names are anonymized following the same rules as other files
|
|
||||||
|
|
||||||
scutil_dns.txt (macOS only):
|
|
||||||
- Contains detailed DNS configuration from scutil --dns
|
|
||||||
- Shows DNS configuration for all network interfaces
|
|
||||||
- Includes search domains, nameservers, and DNS resolver settings
|
|
||||||
- All IP addresses and domain names are anonymized
|
|
||||||
`
|
`
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -220,9 +202,10 @@ type BundleGenerator struct {
|
|||||||
internalConfig *profilemanager.Config
|
internalConfig *profilemanager.Config
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
syncResponse *mgmProto.SyncResponse
|
syncResponse *mgmProto.SyncResponse
|
||||||
logPath string
|
logFile string
|
||||||
|
|
||||||
anonymize bool
|
anonymize bool
|
||||||
|
clientStatus string
|
||||||
includeSystemInfo bool
|
includeSystemInfo bool
|
||||||
logFileCount uint32
|
logFileCount uint32
|
||||||
|
|
||||||
@@ -231,6 +214,7 @@ type BundleGenerator struct {
|
|||||||
|
|
||||||
type BundleConfig struct {
|
type BundleConfig struct {
|
||||||
Anonymize bool
|
Anonymize bool
|
||||||
|
ClientStatus string
|
||||||
IncludeSystemInfo bool
|
IncludeSystemInfo bool
|
||||||
LogFileCount uint32
|
LogFileCount uint32
|
||||||
}
|
}
|
||||||
@@ -239,7 +223,7 @@ type GeneratorDependencies struct {
|
|||||||
InternalConfig *profilemanager.Config
|
InternalConfig *profilemanager.Config
|
||||||
StatusRecorder *peer.Status
|
StatusRecorder *peer.Status
|
||||||
SyncResponse *mgmProto.SyncResponse
|
SyncResponse *mgmProto.SyncResponse
|
||||||
LogPath string
|
LogFile string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||||
@@ -255,9 +239,10 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
internalConfig: deps.InternalConfig,
|
internalConfig: deps.InternalConfig,
|
||||||
statusRecorder: deps.StatusRecorder,
|
statusRecorder: deps.StatusRecorder,
|
||||||
syncResponse: deps.SyncResponse,
|
syncResponse: deps.SyncResponse,
|
||||||
logPath: deps.LogPath,
|
logFile: deps.LogFile,
|
||||||
|
|
||||||
anonymize: cfg.Anonymize,
|
anonymize: cfg.Anonymize,
|
||||||
|
clientStatus: cfg.ClientStatus,
|
||||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||||
logFileCount: logFileCount,
|
logFileCount: logFileCount,
|
||||||
}
|
}
|
||||||
@@ -303,6 +288,13 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
return fmt.Errorf("add status: %w", err)
|
return fmt.Errorf("add status: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if g.statusRecorder != nil {
|
||||||
|
status := g.statusRecorder.GetFullStatus()
|
||||||
|
seedFromStatus(g.anonymizer, &status)
|
||||||
|
} else {
|
||||||
|
log.Debugf("no status recorder available for seeding")
|
||||||
|
}
|
||||||
|
|
||||||
if err := g.addConfig(); err != nil {
|
if err := g.addConfig(); err != nil {
|
||||||
log.Errorf("failed to add config to debug bundle: %v", err)
|
log.Errorf("failed to add config to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
@@ -335,7 +327,7 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add wg show output: %v", err)
|
log.Errorf("failed to add wg show output: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
|
||||||
if err := g.addLogfile(); err != nil {
|
if err := g.addLogfile(); err != nil {
|
||||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||||
if err := g.trySystemdLogFallback(); err != nil {
|
if err := g.trySystemdLogFallback(); err != nil {
|
||||||
@@ -365,10 +357,6 @@ func (g *BundleGenerator) addSystemInfo() {
|
|||||||
if err := g.addFirewallRules(); err != nil {
|
if err := g.addFirewallRules(); err != nil {
|
||||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addDNSInfo(); err != nil {
|
|
||||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addReadme() error {
|
func (g *BundleGenerator) addReadme() error {
|
||||||
@@ -380,26 +368,11 @@ func (g *BundleGenerator) addReadme() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addStatus() error {
|
func (g *BundleGenerator) addStatus() error {
|
||||||
if g.statusRecorder != nil {
|
if status := g.clientStatus; status != "" {
|
||||||
pm := profilemanager.NewProfileManager()
|
statusReader := strings.NewReader(status)
|
||||||
var profName string
|
|
||||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
|
||||||
profName = activeProf.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
fullStatus := g.statusRecorder.GetFullStatus()
|
|
||||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
|
||||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
|
|
||||||
statusOutput := nbstatus.ParseToFullDetailSummary(overview)
|
|
||||||
|
|
||||||
statusReader := strings.NewReader(statusOutput)
|
|
||||||
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
|
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
|
||||||
return fmt.Errorf("add status file to zip: %w", err)
|
return fmt.Errorf("add status file to zip: %w", err)
|
||||||
}
|
}
|
||||||
seedFromStatus(g.anonymizer, &fullStatus)
|
|
||||||
} else {
|
|
||||||
log.Debugf("no status recorder available for seeding")
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -669,14 +642,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addLogfile() error {
|
func (g *BundleGenerator) addLogfile() error {
|
||||||
if g.logPath == "" {
|
if g.logFile == "" {
|
||||||
log.Debugf("skipping empty log file in debug bundle")
|
log.Debugf("skipping empty log file in debug bundle")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logDir := filepath.Dir(g.logPath)
|
logDir := filepath.Dir(g.logFile)
|
||||||
|
|
||||||
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
|
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
|
||||||
return fmt.Errorf("add client log file to zip: %w", err)
|
return fmt.Errorf("add client log file to zip: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,53 +0,0 @@
|
|||||||
//go:build darwin && !ios
|
|
||||||
|
|
||||||
package debug
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
|
||||||
func (g *BundleGenerator) addDNSInfo() error {
|
|
||||||
if err := g.addResolvConf(); err != nil {
|
|
||||||
log.Errorf("failed to add resolv.conf: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := g.addScutilDNS(); err != nil {
|
|
||||||
log.Errorf("failed to add scutil DNS output: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *BundleGenerator) addScutilDNS() error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
cmd := exec.CommandContext(ctx, "scutil", "--dns")
|
|
||||||
output, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("execute scutil --dns: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(bytes.TrimSpace(output)) == 0 {
|
|
||||||
return fmt.Errorf("no scutil DNS output")
|
|
||||||
}
|
|
||||||
|
|
||||||
content := string(output)
|
|
||||||
if g.anonymize {
|
|
||||||
content = g.anonymizer.AnonymizeString(content)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
|
|
||||||
return fmt.Errorf("add scutil DNS output to zip: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -5,7 +5,3 @@ package debug
|
|||||||
func (g *BundleGenerator) addRoutes() error {
|
func (g *BundleGenerator) addRoutes() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addDNSInfo() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
//go:build unix && !darwin && !android
|
|
||||||
|
|
||||||
package debug
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
|
||||||
func (g *BundleGenerator) addDNSInfo() error {
|
|
||||||
if err := g.addResolvConf(); err != nil {
|
|
||||||
log.Errorf("failed to add resolv.conf: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
//go:build !unix
|
|
||||||
|
|
||||||
package debug
|
|
||||||
|
|
||||||
func (g *BundleGenerator) addDNSInfo() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
//go:build unix && !android
|
|
||||||
|
|
||||||
package debug
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
const resolvConfPath = "/etc/resolv.conf"
|
|
||||||
|
|
||||||
func (g *BundleGenerator) addResolvConf() error {
|
|
||||||
data, err := os.ReadFile(resolvConfPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read %s: %w", resolvConfPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
content := string(data)
|
|
||||||
if g.anonymize {
|
|
||||||
content = g.anonymizer.AnonymizeString(content)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
|
|
||||||
return fmt.Errorf("add resolv.conf to zip: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
package debug
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/upload-server/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
const maxBundleUploadSize = 50 * 1024 * 1024
|
|
||||||
|
|
||||||
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
|
|
||||||
response, err := getUploadURL(ctx, url, managementURL)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = upload(ctx, filePath, response)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return response.Key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
|
|
||||||
fileData, err := os.Open(filePath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("open file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer fileData.Close()
|
|
||||||
|
|
||||||
stat, err := fileData.Stat()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("stat file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if stat.Size() > maxBundleUploadSize {
|
|
||||||
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create PUT request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.ContentLength = stat.Size()
|
|
||||||
req.Header.Set("Content-Type", "application/octet-stream")
|
|
||||||
|
|
||||||
putResp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("upload failed: %v", err)
|
|
||||||
}
|
|
||||||
defer putResp.Body.Close()
|
|
||||||
|
|
||||||
if putResp.StatusCode != http.StatusOK {
|
|
||||||
body, _ := io.ReadAll(putResp.Body)
|
|
||||||
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
|
|
||||||
id := getURLHash(managementURL)
|
|
||||||
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create GET request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(getReq)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get presigned URL: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
urlBytes, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read response body: %w", err)
|
|
||||||
}
|
|
||||||
var response types.GetURLResponse
|
|
||||||
if err := json.Unmarshal(urlBytes, &response); err != nil {
|
|
||||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
|
||||||
}
|
|
||||||
return &response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getURLHash(url string) string {
|
|
||||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
|
|
||||||
}
|
|
||||||
@@ -38,8 +38,6 @@ type DeviceAuthProviderConfig struct {
|
|||||||
Scope string
|
Scope string
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
UseIDToken bool
|
UseIDToken bool
|
||||||
// LoginHint is used to pre-fill the email/username field during authentication
|
|
||||||
LoginHint string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||||
|
|||||||
@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
privKey, _ := wgtypes.GenerateKey()
|
privKey, _ := wgtypes.GenerateKey()
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create stdnet: %v", err)
|
t.Errorf("create stdnet: %v", err)
|
||||||
return
|
return
|
||||||
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create stdnet: %v", err)
|
t.Fatalf("create stdnet: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/debug"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
@@ -49,7 +48,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/jobexec"
|
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
@@ -134,11 +132,6 @@ type EngineConfig struct {
|
|||||||
LazyConnectionEnabled bool
|
LazyConnectionEnabled bool
|
||||||
|
|
||||||
MTU uint16
|
MTU uint16
|
||||||
|
|
||||||
// for debug bundle generation
|
|
||||||
ProfileConfig *profilemanager.Config
|
|
||||||
|
|
||||||
LogPath string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@@ -202,8 +195,7 @@ type Engine struct {
|
|||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
srWatcher *guard.SRWatcher
|
srWatcher *guard.SRWatcher
|
||||||
|
|
||||||
// Sync response persistence (protected by syncRespMux)
|
// Sync response persistence
|
||||||
syncRespMux sync.RWMutex
|
|
||||||
persistSyncResponse bool
|
persistSyncResponse bool
|
||||||
latestSyncResponse *mgmProto.SyncResponse
|
latestSyncResponse *mgmProto.SyncResponse
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
@@ -216,9 +208,6 @@ type Engine struct {
|
|||||||
shutdownWg sync.WaitGroup
|
shutdownWg sync.WaitGroup
|
||||||
|
|
||||||
probeStunTurn *relay.StunTurnProbe
|
probeStunTurn *relay.StunTurnProbe
|
||||||
|
|
||||||
jobExecutor *jobexec.Executor
|
|
||||||
jobExecutorWG sync.WaitGroup
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -232,7 +221,17 @@ type localIpUpdater interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewEngine creates a new Connection Engine with probes attached
|
// NewEngine creates a new Connection Engine with probes attached
|
||||||
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, c *profilemanager.Config) *Engine {
|
func NewEngine(
|
||||||
|
clientCtx context.Context,
|
||||||
|
clientCancel context.CancelFunc,
|
||||||
|
signalClient signal.Client,
|
||||||
|
mgmClient mgm.Client,
|
||||||
|
relayManager *relayClient.Manager,
|
||||||
|
config *EngineConfig,
|
||||||
|
mobileDep MobileDependency,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
checks []*mgmProto.Checks,
|
||||||
|
) *Engine {
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
clientCancel: clientCancel,
|
clientCancel: clientCancel,
|
||||||
@@ -251,7 +250,6 @@ func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signa
|
|||||||
checks: checks,
|
checks: checks,
|
||||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||||
jobExecutor: jobexec.NewExecutor(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := profilemanager.NewServiceManager("")
|
sm := profilemanager.NewServiceManager("")
|
||||||
@@ -332,8 +330,6 @@ func (e *Engine) Stop() error {
|
|||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
|
|
||||||
// stop flow manager after wg interface is gone
|
// stop flow manager after wg interface is gone
|
||||||
@@ -520,7 +516,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
e.receiveSignalEvents()
|
e.receiveSignalEvents()
|
||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
e.receiveJobEvents()
|
|
||||||
|
|
||||||
// starting network monitor at the very last to avoid disruptions
|
// starting network monitor at the very last to avoid disruptions
|
||||||
e.startNetworkMonitor()
|
e.startNetworkMonitor()
|
||||||
@@ -798,18 +793,9 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
|
||||||
// Read the storage-enabled flag under the syncRespMux too.
|
|
||||||
e.syncRespMux.RLock()
|
|
||||||
enabled := e.persistSyncResponse
|
|
||||||
e.syncRespMux.RUnlock()
|
|
||||||
|
|
||||||
// Store sync response if persistence is enabled
|
// Store sync response if persistence is enabled
|
||||||
if enabled {
|
if e.persistSyncResponse {
|
||||||
e.syncRespMux.Lock()
|
|
||||||
e.latestSyncResponse = update
|
e.latestSyncResponse = update
|
||||||
e.syncRespMux.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -939,77 +925,6 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (e *Engine) receiveJobEvents() {
|
|
||||||
e.jobExecutorWG.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer e.jobExecutorWG.Done()
|
|
||||||
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
|
|
||||||
resp := mgmProto.JobResponse{
|
|
||||||
ID: msg.ID,
|
|
||||||
Status: mgmProto.JobStatus_failed,
|
|
||||||
}
|
|
||||||
switch params := msg.WorkloadParameters.(type) {
|
|
||||||
case *mgmProto.JobRequest_Bundle:
|
|
||||||
bundleResult, err := e.handleBundle(params.Bundle)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("handling bundle: %v", err)
|
|
||||||
resp.Reason = []byte(err.Error())
|
|
||||||
return &resp
|
|
||||||
}
|
|
||||||
resp.Status = mgmProto.JobStatus_succeeded
|
|
||||||
resp.WorkloadResults = bundleResult
|
|
||||||
return &resp
|
|
||||||
default:
|
|
||||||
resp.Reason = []byte(jobexec.ErrJobNotImplemented.Error())
|
|
||||||
return &resp
|
|
||||||
}
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
// happens if management is unavailable for a long time.
|
|
||||||
// We want to cancel the operation of the whole client
|
|
||||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
|
||||||
e.clientCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Info("stopped receiving jobs from Management Service")
|
|
||||||
}()
|
|
||||||
log.Info("connecting to Management Service jobs stream")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
|
|
||||||
log.Infof("handle remote debug bundle request: %s", params.String())
|
|
||||||
syncResponse, err := e.GetLatestSyncResponse()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("get latest sync response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bundleDeps := debug.GeneratorDependencies{
|
|
||||||
InternalConfig: e.config.ProfileConfig,
|
|
||||||
StatusRecorder: e.statusRecorder,
|
|
||||||
SyncResponse: syncResponse,
|
|
||||||
LogPath: e.config.LogPath,
|
|
||||||
}
|
|
||||||
|
|
||||||
bundleJobParams := debug.BundleConfig{
|
|
||||||
Anonymize: params.Anonymize,
|
|
||||||
IncludeSystemInfo: true,
|
|
||||||
LogFileCount: uint32(params.LogFileCount),
|
|
||||||
}
|
|
||||||
|
|
||||||
waitFor := time.Duration(params.BundleForTime) * time.Minute
|
|
||||||
|
|
||||||
uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, waitFor, e.config.ProfileConfig.ManagementURL.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
response := &mgmProto.JobResponse_Bundle{
|
|
||||||
Bundle: &mgmProto.BundleResult{
|
|
||||||
UploadKey: uploadKey,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||||
@@ -1277,7 +1192,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) //nolint
|
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
||||||
if forwarderPort == 0 {
|
if forwarderPort == 0 {
|
||||||
forwarderPort = nbdns.ForwarderClientPort
|
forwarderPort = nbdns.ForwarderClientPort
|
||||||
}
|
}
|
||||||
@@ -1870,8 +1785,8 @@ func (e *Engine) stopDNSServer() {
|
|||||||
|
|
||||||
// SetSyncResponsePersistence enables or disables sync response persistence
|
// SetSyncResponsePersistence enables or disables sync response persistence
|
||||||
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
||||||
e.syncRespMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncRespMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
if enabled == e.persistSyncResponse {
|
if enabled == e.persistSyncResponse {
|
||||||
return
|
return
|
||||||
@@ -1886,22 +1801,20 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
|||||||
|
|
||||||
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
|
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
|
||||||
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||||
e.syncRespMux.RLock()
|
e.syncMsgMux.Lock()
|
||||||
enabled := e.persistSyncResponse
|
defer e.syncMsgMux.Unlock()
|
||||||
latest := e.latestSyncResponse
|
|
||||||
e.syncRespMux.RUnlock()
|
|
||||||
|
|
||||||
if !enabled {
|
if !e.persistSyncResponse {
|
||||||
return nil, errors.New("sync response persistence is disabled")
|
return nil, errors.New("sync response persistence is disabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
if latest == nil {
|
if e.latestSyncResponse == nil {
|
||||||
//nolint:nilnil
|
//nolint:nilnil
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
|
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
|
||||||
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
|
sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("failed to clone sync response")
|
return nil, fmt.Errorf("failed to clone sync response")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
type sshServer interface {
|
type sshServer interface {
|
||||||
Start(ctx context.Context, addr netip.AddrPort) error
|
Start(ctx context.Context, addr netip.AddrPort) error
|
||||||
Stop() error
|
Stop() error
|
||||||
GetStatus() (bool, []sshserver.SessionInfo)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) setupSSHPortRedirection() error {
|
func (e *Engine) setupSSHPortRedirection() error {
|
||||||
@@ -235,17 +234,7 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
|||||||
|
|
||||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
server.SetNetstackNet(netstackNet)
|
server.SetNetstackNet(netstackNet)
|
||||||
}
|
|
||||||
|
|
||||||
e.configureSSHServer(server)
|
|
||||||
|
|
||||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
|
||||||
return fmt.Errorf("start SSH server: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.sshServer = server
|
|
||||||
|
|
||||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
|
||||||
if registrar, ok := e.firewall.(interface {
|
if registrar, ok := e.firewall.(interface {
|
||||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
}); ok {
|
}); ok {
|
||||||
@@ -254,10 +243,17 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.configureSSHServer(server)
|
||||||
|
e.sshServer = server
|
||||||
|
|
||||||
if err := e.setupSSHPortRedirection(); err != nil {
|
if err := e.setupSSHPortRedirection(); err != nil {
|
||||||
log.Warnf("failed to setup SSH port redirection: %v", err)
|
log.Warnf("failed to setup SSH port redirection: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||||
|
return fmt.Errorf("start SSH server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,16 +336,3 @@ func (e *Engine) stopSSHServer() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSSHServerStatus returns the SSH server status and active sessions
|
|
||||||
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
sshServer := e.sshServer
|
|
||||||
e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
if sshServer == nil {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return sshServer.GetStatus()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,5 +7,5 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList)
|
return stdnet.NewNet(e.config.IFaceBlackList)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ package internal
|
|||||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/pion/transport/v3/stdnet"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -24,15 +25,8 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||||
@@ -249,7 +243,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
},
|
},
|
||||||
MobileDependency{},
|
MobileDependency{},
|
||||||
peer.NewRecorder("https://mgm"),
|
peer.NewRecorder("https://mgm"),
|
||||||
nil, nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@@ -411,13 +405,21 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
engine := NewEngine(
|
||||||
WgIfaceName: "utun102",
|
ctx, cancel,
|
||||||
WgAddr: "100.64.0.1/24",
|
&signal.MockClient{},
|
||||||
WgPrivateKey: key,
|
&mgmt.MockClient{},
|
||||||
WgPort: 33100,
|
relayMgr,
|
||||||
MTU: iface.DefaultMTU,
|
&EngineConfig{
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
WgIfaceName: "utun102",
|
||||||
|
WgAddr: "100.64.0.1/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
},
|
||||||
|
MobileDependency{},
|
||||||
|
peer.NewRecorder("https://mgm"),
|
||||||
|
nil)
|
||||||
|
|
||||||
wgIface := &MockWGIface{
|
wgIface := &MockWGIface{
|
||||||
NameFunc: func() string { return "utun102" },
|
NameFunc: func() string { return "utun102" },
|
||||||
@@ -636,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@@ -801,9 +803,9 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1003,10 +1005,10 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1529,7 +1531,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
|
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||||
e.ctx = ctx
|
e.ctx = ctx
|
||||||
return e, err
|
return e, err
|
||||||
}
|
}
|
||||||
@@ -1588,7 +1590,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
}
|
}
|
||||||
t.Cleanup(cleanUp)
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
jobManager := job.NewJobManager(nil, store)
|
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@@ -1616,16 +1618,13 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
|
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
accountManager, err := server.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
|||||||
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
||||||
log.Debugf("Gathering ICE candidates")
|
log.Debugf("Gathering ICE candidates")
|
||||||
|
|
||||||
agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("create ICE agent: %w", err)
|
return false, fmt.Errorf("create ICE agent: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -23,8 +22,6 @@ const (
|
|||||||
iceFailedTimeoutDefault = 6 * time.Second
|
iceFailedTimeoutDefault = 6 * time.Second
|
||||||
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
||||||
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
||||||
// iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete
|
|
||||||
iceAgentCloseTimeout = 3 * time.Second
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ThreadSafeAgent struct {
|
type ThreadSafeAgent struct {
|
||||||
@@ -35,28 +32,18 @@ type ThreadSafeAgent struct {
|
|||||||
func (a *ThreadSafeAgent) Close() error {
|
func (a *ThreadSafeAgent) Close() error {
|
||||||
var err error
|
var err error
|
||||||
a.once.Do(func() {
|
a.once.Do(func() {
|
||||||
done := make(chan error, 1)
|
err = a.Agent.Close()
|
||||||
go func() {
|
|
||||||
done <- a.Agent.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-done:
|
|
||||||
case <-time.After(iceAgentCloseTimeout):
|
|
||||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
iceFailedTimeout := iceFailedTimeout()
|
iceFailedTimeout := iceFailedTimeout()
|
||||||
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||||
|
|
||||||
transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList)
|
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,11 +3,9 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||||
return stdnet.NewNet(ctx, ifaceBlacklist)
|
return stdnet.NewNet(ifaceBlacklist)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,7 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import (
|
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||||
)
|
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
|
||||||
|
|
||||||
func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
|
||||||
return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ func (w *WorkerICE) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||||
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create agent: %w", err)
|
return nil, fmt.Errorf("create agent: %w", err)
|
||||||
}
|
}
|
||||||
@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
|||||||
|
|
||||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||||
if isController(w.config) {
|
if isController(w.config) {
|
||||||
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||||
} else {
|
} else {
|
||||||
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,8 +44,6 @@ type PKCEAuthProviderConfig struct {
|
|||||||
DisablePromptLogin bool
|
DisablePromptLogin bool
|
||||||
// LoginFlag is used to configure the PKCE flow login behavior
|
// LoginFlag is used to configure the PKCE flow login behavior
|
||||||
LoginFlag common.LoginFlag
|
LoginFlag common.LoginFlag
|
||||||
// LoginHint is used to pre-fill the email/username field during authentication
|
|
||||||
LoginHint string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ type ConfigInput struct {
|
|||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
DisableSSHAuth *bool
|
DisableSSHAuth *bool
|
||||||
SSHJWTCacheTTL *int
|
|
||||||
NATExternalIPs []string
|
NATExternalIPs []string
|
||||||
CustomDNSAddress []byte
|
CustomDNSAddress []byte
|
||||||
RosenpassEnabled *bool
|
RosenpassEnabled *bool
|
||||||
@@ -105,7 +104,6 @@ type Config struct {
|
|||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
DisableSSHAuth *bool
|
DisableSSHAuth *bool
|
||||||
SSHJWTCacheTTL *int
|
|
||||||
|
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
@@ -438,12 +436,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
|
||||||
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
|
||||||
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
|
||||||
updated = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||||
|
|||||||
@@ -132,21 +132,3 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLoginHint retrieves the email from the active profile to use as login_hint.
|
|
||||||
func GetLoginHint() string {
|
|
||||||
pm := NewProfileManager()
|
|
||||||
activeProf, err := pm.GetActiveProfile()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get active profile for login hint: %v", err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return profileState.Email
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
net, err := stdnet.NewNet(ctx, nil)
|
net, err := stdnet.NewNet(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
probeErr = fmt.Errorf("new net: %w", err)
|
probeErr = fmt.Errorf("new net: %w", err)
|
||||||
return
|
return
|
||||||
@@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
net, err := stdnet.NewNet(ctx, nil)
|
net, err := stdnet.NewNet(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
probeErr = fmt.Errorf("new net: %w", err)
|
probeErr = fmt.Errorf("new net: %w", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||||
@@ -38,7 +39,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
|||||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
opts := iface.WGIFaceOpts{
|
opts := iface.WGIFaceOpts{
|
||||||
|
|||||||
@@ -4,28 +4,17 @@
|
|||||||
package stdnet
|
package stdnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const updateInterval = 30 * time.Second
|
||||||
updateInterval = 30 * time.Second
|
|
||||||
dnsResolveTimeout = 30 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
var errNoSuitableAddress = errors.New("no suitable address found")
|
|
||||||
|
|
||||||
// Net is an implementation of the net.Net interface
|
// Net is an implementation of the net.Net interface
|
||||||
// based on functions of the standard net package.
|
// based on functions of the standard net package.
|
||||||
@@ -39,19 +28,12 @@ type Net struct {
|
|||||||
|
|
||||||
// mu is shared between interfaces and lastUpdate
|
// mu is shared between interfaces and lastUpdate
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
// ctx is the context for network operations that supports cancellation
|
|
||||||
ctx context.Context
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNetWithDiscover creates a new StdNet instance.
|
// NewNetWithDiscover creates a new StdNet instance.
|
||||||
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
n := &Net{
|
n := &Net{
|
||||||
interfaceFilter: InterfaceFilter(disallowList),
|
interfaceFilter: InterfaceFilter(disallowList),
|
||||||
ctx: ctx,
|
|
||||||
}
|
}
|
||||||
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
|
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
|
||||||
// so in android cli use pionDiscover
|
// so in android cli use pionDiscover
|
||||||
@@ -64,64 +46,14 @@ func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewNet creates a new StdNet instance.
|
// NewNet creates a new StdNet instance.
|
||||||
func NewNet(ctx context.Context, disallowList []string) (*Net, error) {
|
func NewNet(disallowList []string) (*Net, error) {
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
n := &Net{
|
n := &Net{
|
||||||
iFaceDiscover: pionDiscover{},
|
iFaceDiscover: pionDiscover{},
|
||||||
interfaceFilter: InterfaceFilter(disallowList),
|
interfaceFilter: InterfaceFilter(disallowList),
|
||||||
ctx: ctx,
|
|
||||||
}
|
}
|
||||||
return n, n.UpdateInterfaces()
|
return n, n.UpdateInterfaces()
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveAddr performs DNS resolution with context support and timeout.
|
|
||||||
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
|
|
||||||
host, portStr, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return netip.AddrPort{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
if err != nil {
|
|
||||||
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
|
|
||||||
}
|
|
||||||
if port < 0 || port > 65535 {
|
|
||||||
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipNet := "ip"
|
|
||||||
switch network {
|
|
||||||
case "tcp4", "udp4":
|
|
||||||
ipNet = "ip4"
|
|
||||||
case "tcp6", "udp6":
|
|
||||||
ipNet = "ip6"
|
|
||||||
}
|
|
||||||
|
|
||||||
if host == "" {
|
|
||||||
addr := netip.IPv4Unspecified()
|
|
||||||
if ipNet == "ip6" {
|
|
||||||
addr = netip.IPv6Unspecified()
|
|
||||||
}
|
|
||||||
return netip.AddrPortFrom(addr, uint16(port)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
|
|
||||||
if err != nil {
|
|
||||||
return netip.AddrPort{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(addrs) == 0 {
|
|
||||||
return netip.AddrPort{}, errNoSuitableAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateInterfaces updates the internal list of network interfaces
|
// UpdateInterfaces updates the internal list of network interfaces
|
||||||
// and associated addresses filtering them by name.
|
// and associated addresses filtering them by name.
|
||||||
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
||||||
@@ -205,39 +137,3 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
|
|
||||||
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
|
|
||||||
switch network {
|
|
||||||
case "udp", "udp4", "udp6":
|
|
||||||
case "":
|
|
||||||
network = "udp"
|
|
||||||
default:
|
|
||||||
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
|
||||||
}
|
|
||||||
|
|
||||||
addrPort, err := n.resolveAddr(network, address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
return net.UDPAddrFromAddrPort(addrPort), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
|
|
||||||
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
|
|
||||||
switch network {
|
|
||||||
case "tcp", "tcp4", "tcp6":
|
|
||||||
case "":
|
|
||||||
network = "tcp"
|
|
||||||
default:
|
|
||||||
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
|
||||||
}
|
|
||||||
|
|
||||||
addrPort, err := n.resolveAddr(network, address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
return net.TCPAddrFromAddrPort(addrPort), nil
|
|
||||||
}
|
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,299 +0,0 @@
|
|||||||
package templates
|
|
||||||
|
|
||||||
import (
|
|
||||||
"html/template"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPKCEAuthMsgTemplate(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
data map[string]string
|
|
||||||
outputFile string
|
|
||||||
expectedTitle string
|
|
||||||
expectedInContent []string
|
|
||||||
notExpectedInContent []string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "error_state",
|
|
||||||
data: map[string]string{
|
|
||||||
"Error": "authentication failed: invalid state",
|
|
||||||
},
|
|
||||||
outputFile: "pkce-auth-error.html",
|
|
||||||
expectedTitle: "Login Failed",
|
|
||||||
expectedInContent: []string{
|
|
||||||
"authentication failed: invalid state",
|
|
||||||
"Login Failed",
|
|
||||||
},
|
|
||||||
notExpectedInContent: []string{
|
|
||||||
"Login Successful",
|
|
||||||
"Your device is now registered and logged in to NetBird",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "success_state",
|
|
||||||
data: map[string]string{
|
|
||||||
// No error field means success
|
|
||||||
},
|
|
||||||
outputFile: "pkce-auth-success.html",
|
|
||||||
expectedTitle: "Login Successful",
|
|
||||||
expectedInContent: []string{
|
|
||||||
"Login Successful",
|
|
||||||
"Your device is now registered and logged in to NetBird. You can now close this window.",
|
|
||||||
},
|
|
||||||
notExpectedInContent: []string{
|
|
||||||
"Login Failed",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error_state_timeout",
|
|
||||||
data: map[string]string{
|
|
||||||
"Error": "authentication timeout: request expired after 5 minutes",
|
|
||||||
},
|
|
||||||
outputFile: "pkce-auth-timeout.html",
|
|
||||||
expectedTitle: "Login Failed",
|
|
||||||
expectedInContent: []string{
|
|
||||||
"authentication timeout: request expired after 5 minutes",
|
|
||||||
"Login Failed",
|
|
||||||
},
|
|
||||||
notExpectedInContent: []string{
|
|
||||||
"Login Successful",
|
|
||||||
"Your device is now registered and logged in to NetBird",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// Parse the template
|
|
||||||
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to parse template: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create temp directory for this test
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
outputPath := filepath.Join(tempDir, tt.outputFile)
|
|
||||||
|
|
||||||
// Create output file
|
|
||||||
file, err := os.Create(outputPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create output file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute the template
|
|
||||||
if err := tmpl.Execute(file, tt.data); err != nil {
|
|
||||||
file.Close()
|
|
||||||
t.Fatalf("Failed to execute template: %v", err)
|
|
||||||
}
|
|
||||||
file.Close()
|
|
||||||
|
|
||||||
t.Logf("Generated test output: %s", outputPath)
|
|
||||||
|
|
||||||
// Read the generated file
|
|
||||||
content, err := os.ReadFile(outputPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to read output file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
contentStr := string(content)
|
|
||||||
|
|
||||||
// Verify file has content
|
|
||||||
if len(contentStr) == 0 {
|
|
||||||
t.Error("Output file is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify basic HTML structure
|
|
||||||
basicElements := []string{
|
|
||||||
"<!DOCTYPE html>",
|
|
||||||
"<html",
|
|
||||||
"<head>",
|
|
||||||
"<body>",
|
|
||||||
"NetBird",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, elem := range basicElements {
|
|
||||||
if !contains(contentStr, elem) {
|
|
||||||
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify expected title
|
|
||||||
if !contains(contentStr, tt.expectedTitle) {
|
|
||||||
t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify expected content is present
|
|
||||||
for _, expected := range tt.expectedInContent {
|
|
||||||
if !contains(contentStr, expected) {
|
|
||||||
t.Errorf("Expected HTML to contain '%s', but it was not found", expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify unexpected content is not present
|
|
||||||
for _, notExpected := range tt.notExpectedInContent {
|
|
||||||
if contains(contentStr, notExpected) {
|
|
||||||
t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCEAuthMsgTemplateValidation(t *testing.T) {
|
|
||||||
// Test that the template can be parsed without errors
|
|
||||||
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Template parsing failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test with empty data
|
|
||||||
t.Run("empty_data", func(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
outputPath := filepath.Join(tempDir, "empty-data.html")
|
|
||||||
|
|
||||||
file, err := os.Create(outputPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create output file: %v", err)
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
if err := tmpl.Execute(file, nil); err != nil {
|
|
||||||
t.Errorf("Template execution with nil data failed: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test with error data
|
|
||||||
t.Run("with_error", func(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
outputPath := filepath.Join(tempDir, "with-error.html")
|
|
||||||
|
|
||||||
file, err := os.Create(outputPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create output file: %v", err)
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
data := map[string]string{
|
|
||||||
"Error": "test error message",
|
|
||||||
}
|
|
||||||
if err := tmpl.Execute(file, data); err != nil {
|
|
||||||
t.Errorf("Template execution with error data failed: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCEAuthMsgTemplateContent(t *testing.T) {
|
|
||||||
// Test that the template contains expected elements
|
|
||||||
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Template parsing failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("success_content", func(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
outputPath := filepath.Join(tempDir, "success.html")
|
|
||||||
|
|
||||||
file, err := os.Create(outputPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create output file: %v", err)
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
data := map[string]string{}
|
|
||||||
if err := tmpl.Execute(file, data); err != nil {
|
|
||||||
t.Fatalf("Template execution failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the file and verify it contains expected content
|
|
||||||
content, err := os.ReadFile(outputPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to read output file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for success indicators
|
|
||||||
contentStr := string(content)
|
|
||||||
if len(contentStr) == 0 {
|
|
||||||
t.Error("Generated HTML is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic HTML structure checks
|
|
||||||
requiredElements := []string{
|
|
||||||
"<!DOCTYPE html>",
|
|
||||||
"<html",
|
|
||||||
"<head>",
|
|
||||||
"<body>",
|
|
||||||
"Login Successful",
|
|
||||||
"NetBird",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, elem := range requiredElements {
|
|
||||||
if !contains(contentStr, elem) {
|
|
||||||
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("error_content", func(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
outputPath := filepath.Join(tempDir, "error.html")
|
|
||||||
|
|
||||||
file, err := os.Create(outputPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create output file: %v", err)
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
errorMsg := "test error message"
|
|
||||||
data := map[string]string{
|
|
||||||
"Error": errorMsg,
|
|
||||||
}
|
|
||||||
if err := tmpl.Execute(file, data); err != nil {
|
|
||||||
t.Fatalf("Template execution failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the file and verify it contains expected content
|
|
||||||
content, err := os.ReadFile(outputPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to read output file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for error indicators
|
|
||||||
contentStr := string(content)
|
|
||||||
if len(contentStr) == 0 {
|
|
||||||
t.Error("Generated HTML is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic HTML structure checks
|
|
||||||
requiredElements := []string{
|
|
||||||
"<!DOCTYPE html>",
|
|
||||||
"<html",
|
|
||||||
"<head>",
|
|
||||||
"<body>",
|
|
||||||
"Login Failed",
|
|
||||||
errorMsg,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, elem := range requiredElements {
|
|
||||||
if !contains(contentStr, elem) {
|
|
||||||
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func contains(s, substr string) bool {
|
|
||||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
|
||||||
(len(s) > 0 && len(substr) > 0 && containsHelper(s, substr)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsHelper(s, substr string) bool {
|
|
||||||
for i := 0; i <= len(s)-len(substr); i++ {
|
|
||||||
if s[i:i+len(substr)] == substr {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string {
|
|||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
|
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,66 +0,0 @@
|
|||||||
package jobexec
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/debug"
|
|
||||||
"github.com/netbirdio/netbird/upload-server/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
MaxBundleWaitTime = 60 * time.Minute // maximum wait time for bundle generation (1 hour)
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrJobNotImplemented = errors.New("job not implemented")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Executor struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewExecutor() *Executor {
|
|
||||||
return &Executor{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Executor) BundleJob(ctx context.Context, debugBundleDependencies debug.GeneratorDependencies, params debug.BundleConfig, waitForDuration time.Duration, mgmURL string) (string, error) {
|
|
||||||
if waitForDuration > MaxBundleWaitTime {
|
|
||||||
log.Warnf("bundle wait time %v exceeds maximum %v, capping to maximum", waitForDuration, MaxBundleWaitTime)
|
|
||||||
waitForDuration = MaxBundleWaitTime
|
|
||||||
}
|
|
||||||
|
|
||||||
if waitForDuration > 0 {
|
|
||||||
waitFor(ctx, waitForDuration)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("execute debug bundle generation")
|
|
||||||
|
|
||||||
bundleGenerator := debug.NewBundleGenerator(debugBundleDependencies, params)
|
|
||||||
|
|
||||||
path, err := bundleGenerator.Generate()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("generate debug bundle: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := debug.UploadDebugBundle(ctx, types.DefaultBundleURL, mgmURL, path)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to upload debug bundle: %v", err)
|
|
||||||
return "", fmt.Errorf("upload debug bundle: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("debug bundle has been generated well")
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitFor(ctx context.Context, duration time.Duration) {
|
|
||||||
log.Infof("wait for %v minutes before executing debug bundle", duration.Minutes())
|
|
||||||
select {
|
|
||||||
case <-time.After(duration):
|
|
||||||
case <-ctx.Done():
|
|
||||||
log.Infof("wait cancelled: %v", ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -90,7 +90,7 @@ service DaemonService {
|
|||||||
|
|
||||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||||
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
|
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
|
||||||
|
|
||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||||
}
|
}
|
||||||
@@ -168,15 +168,11 @@ message LoginRequest {
|
|||||||
|
|
||||||
optional int64 mtu = 32;
|
optional int64 mtu = 32;
|
||||||
|
|
||||||
// hint is used to pre-fill the email/username field during SSO authentication
|
optional bool enableSSHRoot = 33;
|
||||||
optional string hint = 33;
|
optional bool enableSSHSFTP = 34;
|
||||||
|
optional bool enableSSHLocalPortForwarding = 35;
|
||||||
optional bool enableSSHRoot = 34;
|
optional bool enableSSHRemotePortForwarding = 36;
|
||||||
optional bool enableSSHSFTP = 35;
|
optional bool disableSSHAuth = 37;
|
||||||
optional bool enableSSHLocalPortForwarding = 36;
|
|
||||||
optional bool enableSSHRemotePortForwarding = 37;
|
|
||||||
optional bool disableSSHAuth = 38;
|
|
||||||
optional int32 sshJWTCacheTTL = 39;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoginResponse {
|
message LoginResponse {
|
||||||
@@ -206,7 +202,7 @@ message StatusRequest{
|
|||||||
bool getFullPeerStatus = 1;
|
bool getFullPeerStatus = 1;
|
||||||
bool shouldRunProbes = 2;
|
bool shouldRunProbes = 2;
|
||||||
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
|
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
|
||||||
optional bool waitForReady = 3;
|
optional bool waitForReady = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message StatusResponse{
|
message StatusResponse{
|
||||||
@@ -281,8 +277,6 @@ message GetConfigResponse {
|
|||||||
bool enableSSHRemotePortForwarding = 23;
|
bool enableSSHRemotePortForwarding = 23;
|
||||||
|
|
||||||
bool disableSSHAuth = 25;
|
bool disableSSHAuth = 25;
|
||||||
|
|
||||||
int32 sshJWTCacheTTL = 26;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
// PeerState contains the latest state of a peer
|
||||||
@@ -346,20 +340,6 @@ message NSGroupState {
|
|||||||
string error = 4;
|
string error = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHSessionInfo contains information about an active SSH session
|
|
||||||
message SSHSessionInfo {
|
|
||||||
string username = 1;
|
|
||||||
string remoteAddress = 2;
|
|
||||||
string command = 3;
|
|
||||||
string jwtUsername = 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
// SSHServerState contains the latest state of the SSH server
|
|
||||||
message SSHServerState {
|
|
||||||
bool enabled = 1;
|
|
||||||
repeated SSHSessionInfo sessions = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// FullStatus contains the full state held by the Status instance
|
// FullStatus contains the full state held by the Status instance
|
||||||
message FullStatus {
|
message FullStatus {
|
||||||
ManagementState managementState = 1;
|
ManagementState managementState = 1;
|
||||||
@@ -373,7 +353,6 @@ message FullStatus {
|
|||||||
repeated SystemEvent events = 7;
|
repeated SystemEvent events = 7;
|
||||||
|
|
||||||
bool lazyConnectionEnabled = 9;
|
bool lazyConnectionEnabled = 9;
|
||||||
SSHServerState sshServerState = 10;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Networks
|
// Networks
|
||||||
@@ -434,6 +413,7 @@ message ForwardingRulesResponse {
|
|||||||
// DebugBundler
|
// DebugBundler
|
||||||
message DebugBundleRequest {
|
message DebugBundleRequest {
|
||||||
bool anonymize = 1;
|
bool anonymize = 1;
|
||||||
|
string status = 2;
|
||||||
bool systemInfo = 3;
|
bool systemInfo = 3;
|
||||||
string uploadURL = 4;
|
string uploadURL = 4;
|
||||||
uint32 logFileCount = 5;
|
uint32 logFileCount = 5;
|
||||||
@@ -639,10 +619,9 @@ message SetConfigRequest {
|
|||||||
|
|
||||||
optional bool enableSSHRoot = 29;
|
optional bool enableSSHRoot = 29;
|
||||||
optional bool enableSSHSFTP = 30;
|
optional bool enableSSHSFTP = 30;
|
||||||
optional bool enableSSHLocalPortForwarding = 31;
|
optional bool enableSSHLocalPortForward = 31;
|
||||||
optional bool enableSSHRemotePortForwarding = 32;
|
optional bool enableSSHRemotePortForward = 32;
|
||||||
optional bool disableSSHAuth = 33;
|
optional bool disableSSHAuth = 33;
|
||||||
optional int32 sshJWTCacheTTL = 34;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message SetConfigResponse{}
|
message SetConfigResponse{}
|
||||||
@@ -715,8 +694,6 @@ message GetPeerSSHHostKeyResponse {
|
|||||||
|
|
||||||
// RequestJWTAuthRequest for initiating JWT authentication flow
|
// RequestJWTAuthRequest for initiating JWT authentication flow
|
||||||
message RequestJWTAuthRequest {
|
message RequestJWTAuthRequest {
|
||||||
// hint for OIDC login_hint parameter (typically email address)
|
|
||||||
optional string hint = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestJWTAuthResponse contains authentication flow information
|
// RequestJWTAuthResponse contains authentication flow information
|
||||||
|
|||||||
@@ -4,16 +4,24 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/debug"
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/upload-server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxBundleUploadSize = 50 * 1024 * 1024
|
||||||
|
|
||||||
// DebugBundle creates a debug bundle and returns the location.
|
// DebugBundle creates a debug bundle and returns the location.
|
||||||
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
@@ -29,10 +37,11 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
|||||||
InternalConfig: s.config,
|
InternalConfig: s.config,
|
||||||
StatusRecorder: s.statusRecorder,
|
StatusRecorder: s.statusRecorder,
|
||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogPath: s.logFile,
|
LogFile: s.logFile,
|
||||||
},
|
},
|
||||||
debug.BundleConfig{
|
debug.BundleConfig{
|
||||||
Anonymize: req.GetAnonymize(),
|
Anonymize: req.GetAnonymize(),
|
||||||
|
ClientStatus: req.GetStatus(),
|
||||||
IncludeSystemInfo: req.GetSystemInfo(),
|
IncludeSystemInfo: req.GetSystemInfo(),
|
||||||
LogFileCount: req.GetLogFileCount(),
|
LogFileCount: req.GetLogFileCount(),
|
||||||
},
|
},
|
||||||
@@ -46,7 +55,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
|||||||
if req.GetUploadURL() == "" {
|
if req.GetUploadURL() == "" {
|
||||||
return &proto.DebugBundleResponse{Path: path}, nil
|
return &proto.DebugBundleResponse{Path: path}, nil
|
||||||
}
|
}
|
||||||
key, err := debug.UploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
|
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
|
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
|
||||||
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
|
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
|
||||||
@@ -57,6 +66,92 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
|||||||
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
|
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
|
||||||
|
response, err := getUploadURL(ctx, url, managementURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = upload(ctx, filePath, response)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return response.Key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
|
||||||
|
fileData, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer fileData.Close()
|
||||||
|
|
||||||
|
stat, err := fileData.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stat file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stat.Size() > maxBundleUploadSize {
|
||||||
|
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create PUT request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.ContentLength = stat.Size()
|
||||||
|
req.Header.Set("Content-Type", "application/octet-stream")
|
||||||
|
|
||||||
|
putResp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("upload failed: %v", err)
|
||||||
|
}
|
||||||
|
defer putResp.Body.Close()
|
||||||
|
|
||||||
|
if putResp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(putResp.Body)
|
||||||
|
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
|
||||||
|
id := getURLHash(managementURL)
|
||||||
|
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create GET request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(getReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get presigned URL: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
urlBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read response body: %w", err)
|
||||||
|
}
|
||||||
|
var response types.GetURLResponse
|
||||||
|
if err := json.Unmarshal(urlBytes, &response); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||||
|
}
|
||||||
|
return &response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getURLHash(url string) string {
|
||||||
|
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
|
||||||
|
}
|
||||||
|
|
||||||
// GetLogLevel gets the current logging level for the server.
|
// GetLogLevel gets the current logging level for the server.
|
||||||
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
|
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package debug
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
|
|||||||
fileContent := []byte("test file content")
|
fileContent := []byte("test file content")
|
||||||
err := os.WriteFile(file, fileContent, 0640)
|
err := os.WriteFile(file, fileContent, 0640)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
id := getURLHash(testURL)
|
id := getURLHash(testURL)
|
||||||
require.Contains(t, key, id+"/")
|
require.Contains(t, key, id+"/")
|
||||||
@@ -37,18 +37,13 @@ func (c *jwtCache) store(token string, maxAge time.Duration) {
|
|||||||
|
|
||||||
c.expiresAt = time.Now().Add(maxAge)
|
c.expiresAt = time.Now().Add(maxAge)
|
||||||
|
|
||||||
var timer *time.Timer
|
c.timer = time.AfterFunc(maxAge, func() {
|
||||||
timer = time.AfterFunc(maxAge, func() {
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
if c.timer != timer {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.cleanup()
|
c.cleanup()
|
||||||
c.timer = nil
|
c.timer = nil
|
||||||
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
|
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
|
||||||
})
|
})
|
||||||
c.timer = timer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *jwtCache) get() (string, bool) {
|
func (c *jwtCache) get() (string, bool) {
|
||||||
@@ -75,5 +70,4 @@ func (c *jwtCache) cleanup() {
|
|||||||
if c.enclave != nil {
|
if c.enclave != nil {
|
||||||
c.enclave = nil
|
c.enclave = nil
|
||||||
}
|
}
|
||||||
c.expiresAt = time.Time{}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,11 +13,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
"golang.org/x/exp/maps"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
@@ -28,7 +32,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,8 +46,8 @@ const (
|
|||||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||||
defaultRetryMultiplier = 1.7
|
defaultRetryMultiplier = 1.7
|
||||||
|
|
||||||
// JWT token cache TTL for the client daemon (disabled by default)
|
// JWT token cache TTL for the client daemon
|
||||||
defaultJWTCacheTTL = 0
|
defaultJWTCacheTTL = 5 * time.Minute
|
||||||
|
|
||||||
errRestoreResidualState = "failed to restore residual state: %v"
|
errRestoreResidualState = "failed to restore residual state: %v"
|
||||||
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
||||||
@@ -378,15 +381,11 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
|||||||
config.BlockInbound = msg.BlockInbound
|
config.BlockInbound = msg.BlockInbound
|
||||||
config.EnableSSHRoot = msg.EnableSSHRoot
|
config.EnableSSHRoot = msg.EnableSSHRoot
|
||||||
config.EnableSSHSFTP = msg.EnableSSHSFTP
|
config.EnableSSHSFTP = msg.EnableSSHSFTP
|
||||||
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding
|
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward
|
||||||
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding
|
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward
|
||||||
if msg.DisableSSHAuth != nil {
|
if msg.DisableSSHAuth != nil {
|
||||||
config.DisableSSHAuth = msg.DisableSSHAuth
|
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||||
}
|
}
|
||||||
if msg.SshJWTCacheTTL != nil {
|
|
||||||
ttl := int(*msg.SshJWTCacheTTL)
|
|
||||||
config.SSHJWTCacheTTL = &ttl
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg.Mtu != nil {
|
if msg.Mtu != nil {
|
||||||
mtu := uint16(*msg.Mtu)
|
mtu := uint16(*msg.Mtu)
|
||||||
@@ -497,11 +496,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
state.Set(internal.StatusConnecting)
|
state.Set(internal.StatusConnecting)
|
||||||
|
|
||||||
if msg.SetupKey == "" {
|
if msg.SetupKey == "" {
|
||||||
hint := ""
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
|
||||||
if msg.Hint != nil {
|
|
||||||
hint = *msg.Hint
|
|
||||||
}
|
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.Set(internal.StatusLoginFailed)
|
state.Set(internal.StatusLoginFailed)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1077,49 +1072,14 @@ func (s *Server) Status(
|
|||||||
if msg.GetFullPeerStatus {
|
if msg.GetFullPeerStatus {
|
||||||
s.runProbes(msg.ShouldRunProbes)
|
s.runProbes(msg.ShouldRunProbes)
|
||||||
fullStatus := s.statusRecorder.GetFullStatus()
|
fullStatus := s.statusRecorder.GetFullStatus()
|
||||||
pbFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||||
|
|
||||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
|
||||||
|
|
||||||
statusResponse.FullStatus = pbFullStatus
|
statusResponse.FullStatus = pbFullStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
return &statusResponse, nil
|
return &statusResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getSSHServerState retrieves the current SSH server state including enabled status and active sessions
|
|
||||||
func (s *Server) getSSHServerState() *proto.SSHServerState {
|
|
||||||
s.mutex.Lock()
|
|
||||||
connectClient := s.connectClient
|
|
||||||
s.mutex.Unlock()
|
|
||||||
|
|
||||||
if connectClient == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
engine := connectClient.Engine()
|
|
||||||
if engine == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
enabled, sessions := engine.GetSSHServerStatus()
|
|
||||||
sshServerState := &proto.SSHServerState{
|
|
||||||
Enabled: enabled,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, session := range sessions {
|
|
||||||
sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{
|
|
||||||
Username: session.Username,
|
|
||||||
RemoteAddress: session.RemoteAddress,
|
|
||||||
Command: session.Command,
|
|
||||||
JwtUsername: session.JWTUsername,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return sshServerState
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||||
func (s *Server) GetPeerSSHHostKey(
|
func (s *Server) GetPeerSSHHostKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -1172,31 +1132,35 @@ func (s *Server) GetPeerSSHHostKey(
|
|||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled)
|
// getJWTCacheTTL returns the JWT cache TTL from environment variable or default
|
||||||
func (s *Server) getJWTCacheTTL() time.Duration {
|
// NB_SSH_JWT_CACHE_TTL=0 disables caching
|
||||||
s.mutex.Lock()
|
// NB_SSH_JWT_CACHE_TTL=<seconds> sets custom cache TTL
|
||||||
config := s.config
|
func getJWTCacheTTL() time.Duration {
|
||||||
s.mutex.Unlock()
|
envValue := os.Getenv("NB_SSH_JWT_CACHE_TTL")
|
||||||
|
if envValue == "" {
|
||||||
if config == nil || config.SSHJWTCacheTTL == nil {
|
return defaultJWTCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
seconds, err := strconv.Atoi(envValue)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("invalid NB_SSH_JWT_CACHE_TTL value %s, using default: %v", envValue, defaultJWTCacheTTL)
|
||||||
return defaultJWTCacheTTL
|
return defaultJWTCacheTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
seconds := *config.SSHJWTCacheTTL
|
|
||||||
if seconds == 0 {
|
if seconds == 0 {
|
||||||
log.Debug("SSH JWT cache disabled (configured to 0)")
|
log.Info("SSH JWT cache disabled via NB_SSH_JWT_CACHE_TTL=0")
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
ttl := time.Duration(seconds) * time.Second
|
ttl := time.Duration(seconds) * time.Second
|
||||||
log.Debugf("SSH JWT cache TTL set to %v from config", ttl)
|
log.Infof("SSH JWT cache TTL set to %v via NB_SSH_JWT_CACHE_TTL", ttl)
|
||||||
return ttl
|
return ttl
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||||
func (s *Server) RequestJWTAuth(
|
func (s *Server) RequestJWTAuth(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
msg *proto.RequestJWTAuthRequest,
|
_ *proto.RequestJWTAuthRequest,
|
||||||
) (*proto.RequestJWTAuthResponse, error) {
|
) (*proto.RequestJWTAuthResponse, error) {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
@@ -1210,7 +1174,7 @@ func (s *Server) RequestJWTAuth(
|
|||||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
|
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtCacheTTL := s.getJWTCacheTTL()
|
jwtCacheTTL := getJWTCacheTTL()
|
||||||
if jwtCacheTTL > 0 {
|
if jwtCacheTTL > 0 {
|
||||||
if cachedToken, found := s.jwtCache.get(); found {
|
if cachedToken, found := s.jwtCache.get(); found {
|
||||||
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
|
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
|
||||||
@@ -1222,17 +1186,8 @@ func (s *Server) RequestJWTAuth(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hint := ""
|
|
||||||
if msg.Hint != nil {
|
|
||||||
hint = *msg.Hint
|
|
||||||
}
|
|
||||||
|
|
||||||
if hint == "" {
|
|
||||||
hint = profilemanager.GetLoginHint()
|
|
||||||
}
|
|
||||||
|
|
||||||
isDesktop := isUnixRunningDesktop()
|
isDesktop := isUnixRunningDesktop()
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint)
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1283,7 +1238,7 @@ func (s *Server) WaitJWTToken(
|
|||||||
|
|
||||||
token := tokenInfo.GetTokenToUse()
|
token := tokenInfo.GetTokenToUse()
|
||||||
|
|
||||||
jwtCacheTTL := s.getJWTCacheTTL()
|
jwtCacheTTL := getJWTCacheTTL()
|
||||||
if jwtCacheTTL > 0 {
|
if jwtCacheTTL > 0 {
|
||||||
s.jwtCache.store(token, jwtCacheTTL)
|
s.jwtCache.store(token, jwtCacheTTL)
|
||||||
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
|
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
|
||||||
@@ -1374,33 +1329,28 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
blockLANAccess := cfg.BlockLANAccess
|
blockLANAccess := cfg.BlockLANAccess
|
||||||
|
|
||||||
enableSSHRoot := false
|
enableSSHRoot := false
|
||||||
if cfg.EnableSSHRoot != nil {
|
if s.config.EnableSSHRoot != nil {
|
||||||
enableSSHRoot = *cfg.EnableSSHRoot
|
enableSSHRoot = *s.config.EnableSSHRoot
|
||||||
}
|
}
|
||||||
|
|
||||||
enableSSHSFTP := false
|
enableSSHSFTP := false
|
||||||
if cfg.EnableSSHSFTP != nil {
|
if s.config.EnableSSHSFTP != nil {
|
||||||
enableSSHSFTP = *cfg.EnableSSHSFTP
|
enableSSHSFTP = *s.config.EnableSSHSFTP
|
||||||
}
|
}
|
||||||
|
|
||||||
enableSSHLocalPortForwarding := false
|
enableSSHLocalPortForwarding := false
|
||||||
if cfg.EnableSSHLocalPortForwarding != nil {
|
if s.config.EnableSSHLocalPortForwarding != nil {
|
||||||
enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding
|
enableSSHLocalPortForwarding = *s.config.EnableSSHLocalPortForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
enableSSHRemotePortForwarding := false
|
enableSSHRemotePortForwarding := false
|
||||||
if cfg.EnableSSHRemotePortForwarding != nil {
|
if s.config.EnableSSHRemotePortForwarding != nil {
|
||||||
enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding
|
enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
disableSSHAuth := false
|
disableSSHAuth := false
|
||||||
if cfg.DisableSSHAuth != nil {
|
if s.config.DisableSSHAuth != nil {
|
||||||
disableSSHAuth = *cfg.DisableSSHAuth
|
disableSSHAuth = *s.config.DisableSSHAuth
|
||||||
}
|
|
||||||
|
|
||||||
sshJWTCacheTTL := int32(0)
|
|
||||||
if cfg.SSHJWTCacheTTL != nil {
|
|
||||||
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &proto.GetConfigResponse{
|
return &proto.GetConfigResponse{
|
||||||
@@ -1427,7 +1377,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||||
DisableSSHAuth: disableSSHAuth,
|
DisableSSHAuth: disableSSHAuth,
|
||||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1535,7 +1484,7 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
|
|||||||
log.Tracef("running client connection")
|
log.Tracef("running client connection")
|
||||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
|
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
|
||||||
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||||
if err := s.connectClient.Run(runningChan, s.logFile); err != nil {
|
if err := s.connectClient.Run(runningChan); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -1609,6 +1558,94 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
|
|||||||
return defaultDuration
|
return defaultDuration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||||
|
pbFullStatus := proto.FullStatus{
|
||||||
|
ManagementState: &proto.ManagementState{},
|
||||||
|
SignalState: &proto.SignalState{},
|
||||||
|
LocalPeerState: &proto.LocalPeerState{},
|
||||||
|
Peers: []*proto.PeerState{},
|
||||||
|
}
|
||||||
|
|
||||||
|
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
|
||||||
|
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
|
||||||
|
if err := fullStatus.ManagementState.Error; err != nil {
|
||||||
|
pbFullStatus.ManagementState.Error = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
|
||||||
|
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
|
||||||
|
if err := fullStatus.SignalState.Error; err != nil {
|
||||||
|
pbFullStatus.SignalState.Error = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
|
||||||
|
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
|
||||||
|
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
|
||||||
|
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
|
||||||
|
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
|
||||||
|
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
||||||
|
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
||||||
|
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
||||||
|
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
|
||||||
|
|
||||||
|
for _, peerState := range fullStatus.Peers {
|
||||||
|
pbPeerState := &proto.PeerState{
|
||||||
|
IP: peerState.IP,
|
||||||
|
PubKey: peerState.PubKey,
|
||||||
|
ConnStatus: peerState.ConnStatus.String(),
|
||||||
|
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||||
|
Relayed: peerState.Relayed,
|
||||||
|
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
||||||
|
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
||||||
|
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
||||||
|
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
||||||
|
RelayAddress: peerState.RelayServerAddress,
|
||||||
|
Fqdn: peerState.FQDN,
|
||||||
|
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
||||||
|
BytesRx: peerState.BytesRx,
|
||||||
|
BytesTx: peerState.BytesTx,
|
||||||
|
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||||
|
Networks: maps.Keys(peerState.GetRoutes()),
|
||||||
|
Latency: durationpb.New(peerState.Latency),
|
||||||
|
SshHostKey: peerState.SSHHostKey,
|
||||||
|
}
|
||||||
|
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, relayState := range fullStatus.Relays {
|
||||||
|
pbRelayState := &proto.RelayState{
|
||||||
|
URI: relayState.URI,
|
||||||
|
Available: relayState.Err == nil,
|
||||||
|
}
|
||||||
|
if err := relayState.Err; err != nil {
|
||||||
|
pbRelayState.Error = err.Error()
|
||||||
|
}
|
||||||
|
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dnsState := range fullStatus.NSGroupStates {
|
||||||
|
var err string
|
||||||
|
if dnsState.Error != nil {
|
||||||
|
err = dnsState.Error.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
var servers []string
|
||||||
|
for _, server := range dnsState.Servers {
|
||||||
|
servers = append(servers, server.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
pbDnsState := &proto.NSGroupState{
|
||||||
|
Servers: servers,
|
||||||
|
Domains: dnsState.Domains,
|
||||||
|
Enabled: dnsState.Enabled,
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pbFullStatus
|
||||||
|
}
|
||||||
|
|
||||||
// sendTerminalNotification sends a terminal notification message
|
// sendTerminalNotification sends a terminal notification message
|
||||||
// to inform the user that the NetBird connection session has expired.
|
// to inform the user that the NetBird connection session has expired.
|
||||||
func sendTerminalNotification() error {
|
func sendTerminalNotification() error {
|
||||||
|
|||||||
@@ -15,11 +15,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||||
@@ -295,7 +290,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
}
|
}
|
||||||
t.Cleanup(cleanUp)
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
jobManager := job.NewJobManager(nil, store)
|
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@@ -316,16 +311,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
accountManager, err := server.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,7 +72,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
lazyConnectionEnabled := true
|
lazyConnectionEnabled := true
|
||||||
blockInbound := true
|
blockInbound := true
|
||||||
mtu := int64(1280)
|
mtu := int64(1280)
|
||||||
sshJWTCacheTTL := int32(300)
|
|
||||||
|
|
||||||
req := &proto.SetConfigRequest{
|
req := &proto.SetConfigRequest{
|
||||||
ProfileName: profName,
|
ProfileName: profName,
|
||||||
@@ -103,7 +102,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
CleanDNSLabels: false,
|
CleanDNSLabels: false,
|
||||||
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
||||||
Mtu: &mtu,
|
Mtu: &mtu,
|
||||||
SshJWTCacheTTL: &sshJWTCacheTTL,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.SetConfig(ctx, req)
|
_, err = s.SetConfig(ctx, req)
|
||||||
@@ -148,8 +146,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
|
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
|
||||||
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
|
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
|
||||||
require.Equal(t, uint16(mtu), cfg.MTU)
|
require.Equal(t, uint16(mtu), cfg.MTU)
|
||||||
require.NotNil(t, cfg.SSHJWTCacheTTL)
|
|
||||||
require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL)
|
|
||||||
|
|
||||||
verifyAllFieldsCovered(t, req)
|
verifyAllFieldsCovered(t, req)
|
||||||
}
|
}
|
||||||
@@ -171,36 +167,35 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
expectedFields := map[string]bool{
|
expectedFields := map[string]bool{
|
||||||
"ManagementUrl": true,
|
"ManagementUrl": true,
|
||||||
"AdminURL": true,
|
"AdminURL": true,
|
||||||
"RosenpassEnabled": true,
|
"RosenpassEnabled": true,
|
||||||
"RosenpassPermissive": true,
|
"RosenpassPermissive": true,
|
||||||
"ServerSSHAllowed": true,
|
"ServerSSHAllowed": true,
|
||||||
"InterfaceName": true,
|
"InterfaceName": true,
|
||||||
"WireguardPort": true,
|
"WireguardPort": true,
|
||||||
"OptionalPreSharedKey": true,
|
"OptionalPreSharedKey": true,
|
||||||
"DisableAutoConnect": true,
|
"DisableAutoConnect": true,
|
||||||
"NetworkMonitor": true,
|
"NetworkMonitor": true,
|
||||||
"DisableClientRoutes": true,
|
"DisableClientRoutes": true,
|
||||||
"DisableServerRoutes": true,
|
"DisableServerRoutes": true,
|
||||||
"DisableDns": true,
|
"DisableDns": true,
|
||||||
"DisableFirewall": true,
|
"DisableFirewall": true,
|
||||||
"BlockLanAccess": true,
|
"BlockLanAccess": true,
|
||||||
"DisableNotifications": true,
|
"DisableNotifications": true,
|
||||||
"LazyConnectionEnabled": true,
|
"LazyConnectionEnabled": true,
|
||||||
"BlockInbound": true,
|
"BlockInbound": true,
|
||||||
"NatExternalIPs": true,
|
"NatExternalIPs": true,
|
||||||
"CustomDNSAddress": true,
|
"CustomDNSAddress": true,
|
||||||
"ExtraIFaceBlacklist": true,
|
"ExtraIFaceBlacklist": true,
|
||||||
"DnsLabels": true,
|
"DnsLabels": true,
|
||||||
"DnsRouteInterval": true,
|
"DnsRouteInterval": true,
|
||||||
"Mtu": true,
|
"Mtu": true,
|
||||||
"EnableSSHRoot": true,
|
"EnableSSHRoot": true,
|
||||||
"EnableSSHSFTP": true,
|
"EnableSSHSFTP": true,
|
||||||
"EnableSSHLocalPortForwarding": true,
|
"EnableSSHLocalPortForward": true,
|
||||||
"EnableSSHRemotePortForwarding": true,
|
"EnableSSHRemotePortForward": true,
|
||||||
"DisableSSHAuth": true,
|
"DisableSSHAuth": true,
|
||||||
"SshJWTCacheTTL": true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val := reflect.ValueOf(req).Elem()
|
val := reflect.ValueOf(req).Elem()
|
||||||
@@ -256,10 +251,9 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
|||||||
"mtu": "Mtu",
|
"mtu": "Mtu",
|
||||||
"enable-ssh-root": "EnableSSHRoot",
|
"enable-ssh-root": "EnableSSHRoot",
|
||||||
"enable-ssh-sftp": "EnableSSHSFTP",
|
"enable-ssh-sftp": "EnableSSHSFTP",
|
||||||
"enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding",
|
"enable-ssh-local-port-forwarding": "EnableSSHLocalPortForward",
|
||||||
"enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding",
|
"enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForward",
|
||||||
"disable-ssh-auth": "DisableSSHAuth",
|
"disable-ssh-auth": "DisableSSHAuth",
|
||||||
"ssh-jwt-cache-ttl": "SshJWTCacheTTL",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
|
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
@@ -215,7 +214,7 @@ func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCommandError processes command execution errors
|
// handleCommandError processes command execution errors, treating exit codes as normal
|
||||||
func (c *Client) handleCommandError(err error) error {
|
func (c *Client) handleCommandError(err error) error {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -223,11 +222,11 @@ func (c *Client) handleCommandError(err error) error {
|
|||||||
|
|
||||||
var e *ssh.ExitError
|
var e *ssh.ExitError
|
||||||
var em *ssh.ExitMissingError
|
var em *ssh.ExitMissingError
|
||||||
if errors.As(err, &e) || errors.As(err, &em) {
|
if !errors.As(err, &e) && !errors.As(err, &em) {
|
||||||
return err
|
return fmt.Errorf("execute command: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("execute command: %w", err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupContextCancellation sets up context cancellation for a session
|
// setupContextCancellation sets up context cancellation for a session
|
||||||
@@ -282,12 +281,6 @@ type DialOptions struct {
|
|||||||
|
|
||||||
// Dial connects to the given ssh server with specified options
|
// Dial connects to the given ssh server with specified options
|
||||||
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
||||||
daemonAddr := opts.DaemonAddr
|
|
||||||
if daemonAddr == "" {
|
|
||||||
daemonAddr = getDefaultDaemonAddr()
|
|
||||||
}
|
|
||||||
opts.DaemonAddr = daemonAddr
|
|
||||||
|
|
||||||
hostKeyCallback, err := createHostKeyCallback(opts)
|
hostKeyCallback, err := createHostKeyCallback(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create host key callback: %w", err)
|
return nil, fmt.Errorf("create host key callback: %w", err)
|
||||||
@@ -307,6 +300,11 @@ func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, er
|
|||||||
config.Auth = append(config.Auth, authMethod)
|
config.Auth = append(config.Auth, authMethod)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
daemonAddr := opts.DaemonAddr
|
||||||
|
if daemonAddr == "" {
|
||||||
|
daemonAddr = getDefaultDaemonAddr()
|
||||||
|
}
|
||||||
|
|
||||||
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
|
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -367,8 +365,6 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
|
|||||||
|
|
||||||
// requestJWTToken requests a JWT token from the NetBird daemon
|
// requestJWTToken requests a JWT token from the NetBird daemon
|
||||||
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
|
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
|
||||||
hint := profilemanager.GetLoginHint()
|
|
||||||
|
|
||||||
conn, err := connectToDaemon(daemonAddr)
|
conn, err := connectToDaemon(daemonAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("connect to daemon: %w", err)
|
return "", fmt.Errorf("connect to daemon: %w", err)
|
||||||
@@ -376,7 +372,7 @@ func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (st
|
|||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint)
|
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
||||||
@@ -468,7 +464,7 @@ func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicK
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname)
|
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getKnownHostsFilesList(knownHostsFile string) []string {
|
func getKnownHostsFilesList(knownHostsFile string) []string {
|
||||||
@@ -512,7 +508,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := localListener.Close(); err != nil {
|
||||||
log.Debugf("local listener close error: %v", err)
|
log.Debugf("local listener close error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -530,9 +526,6 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
|
||||||
log.Debugf("local listener close error: %v", err)
|
|
||||||
}
|
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -137,10 +137,10 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
|
|||||||
const numClients = 3
|
const numClients = 3
|
||||||
clients := make([]*Client, numClients)
|
clients := make([]*Client, numClients)
|
||||||
|
|
||||||
currentUser := testutil.GetTestUsername(t)
|
|
||||||
for i := 0; i < numClients; i++ {
|
for i := 0; i < numClients; i++ {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
currentUser := testutil.GetTestUsername(t)
|
||||||
|
client, err := Dial(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i), DialOptions{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
})
|
})
|
||||||
cancel()
|
cancel()
|
||||||
|
|||||||
@@ -15,26 +15,17 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
|
func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
|
||||||
stdinFd := int(os.Stdin.Fd())
|
fd := int(os.Stdout.Fd())
|
||||||
|
|
||||||
if !term.IsTerminal(stdinFd) {
|
if !term.IsTerminal(fd) {
|
||||||
return c.setupNonTerminalMode(ctx, session)
|
return c.setupNonTerminalMode(ctx, session)
|
||||||
}
|
}
|
||||||
|
|
||||||
fd := int(os.Stdin.Fd())
|
|
||||||
|
|
||||||
state, err := term.MakeRaw(fd)
|
state, err := term.MakeRaw(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.setupNonTerminalMode(ctx, session)
|
return c.setupNonTerminalMode(ctx, session)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.setupTerminal(session, fd); err != nil {
|
|
||||||
if restoreErr := term.Restore(fd, state); restoreErr != nil {
|
|
||||||
log.Debugf("restore terminal state: %v", restoreErr)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.terminalState = state
|
c.terminalState = state
|
||||||
c.terminalFd = fd
|
c.terminalFd = fd
|
||||||
|
|
||||||
@@ -64,10 +55,27 @@ func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) er
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nil
|
return c.setupTerminal(session, fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
|
func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
|
||||||
|
w, h := 80, 24
|
||||||
|
|
||||||
|
modes := ssh.TerminalModes{
|
||||||
|
ssh.ECHO: 1,
|
||||||
|
ssh.TTY_OP_ISPEED: 14400,
|
||||||
|
ssh.TTY_OP_OSPEED: 14400,
|
||||||
|
}
|
||||||
|
|
||||||
|
terminal := os.Getenv("TERM")
|
||||||
|
if terminal == "" {
|
||||||
|
terminal = "xterm-256color"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.RequestPty(terminal, h, w, modes); err != nil {
|
||||||
|
return fmt.Errorf("request pty: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -62,10 +62,11 @@ func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) erro
|
|||||||
if err := c.saveWindowsConsoleState(); err != nil {
|
if err := c.saveWindowsConsoleState(); err != nil {
|
||||||
var consoleErr *ConsoleUnavailableError
|
var consoleErr *ConsoleUnavailableError
|
||||||
if errors.As(err, &consoleErr) {
|
if errors.As(err, &consoleErr) {
|
||||||
log.Debugf("console unavailable, not requesting PTY: %v", err)
|
log.Debugf("console unavailable, continuing with defaults: %v", err)
|
||||||
return nil
|
c.terminalFd = 0
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("save console state: %w", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("save console state: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.enableWindowsVirtualTerminal(); err != nil {
|
if err := c.enableWindowsVirtualTerminal(); err != nil {
|
||||||
@@ -104,14 +105,7 @@ func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) erro
|
|||||||
ssh.VREPRINT: 18, // Ctrl+R
|
ssh.VREPRINT: 18, // Ctrl+R
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := session.RequestPty("xterm-256color", h, w, modes); err != nil {
|
return session.RequestPty("xterm-256color", h, w, modes)
|
||||||
if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil {
|
|
||||||
log.Debugf("restore Windows console state: %v", restoreErr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("request pty: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) saveWindowsConsoleState() error {
|
func (c *Client) saveWindowsConsoleState() error {
|
||||||
|
|||||||
@@ -68,12 +68,8 @@ func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
||||||
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) {
|
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool) (string, error) {
|
||||||
req := &proto.RequestJWTAuthRequest{}
|
authResponse, err := client.RequestJWTAuth(ctx, &proto.RequestJWTAuthRequest{})
|
||||||
if hint != "" {
|
|
||||||
req.Hint = &hint
|
|
||||||
}
|
|
||||||
authResponse, err := client.RequestJWTAuth(ctx, req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("request JWT auth: %w", err)
|
return "", fmt.Errorf("request JWT auth: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
@@ -39,7 +38,6 @@ type SSHProxy struct {
|
|||||||
targetHost string
|
targetHost string
|
||||||
targetPort int
|
targetPort int
|
||||||
stderr io.Writer
|
stderr io.Writer
|
||||||
conn *grpc.ClientConn
|
|
||||||
daemonClient proto.DaemonServiceClient
|
daemonClient proto.DaemonServiceClient
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,22 +53,12 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHP
|
|||||||
targetHost: targetHost,
|
targetHost: targetHost,
|
||||||
targetPort: targetPort,
|
targetPort: targetPort,
|
||||||
stderr: stderr,
|
stderr: stderr,
|
||||||
conn: grpcConn,
|
|
||||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) Close() error {
|
|
||||||
if p.conn != nil {
|
|
||||||
return p.conn.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SSHProxy) Connect(ctx context.Context) error {
|
func (p *SSHProxy) Connect(ctx context.Context) error {
|
||||||
hint := profilemanager.GetLoginHint()
|
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true)
|
||||||
|
|
||||||
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(jwtAuthErrorMsg, err)
|
return fmt.Errorf(jwtAuthErrorMsg, err)
|
||||||
}
|
}
|
||||||
@@ -156,7 +144,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
|
|||||||
if len(session.Command()) > 0 {
|
if len(session.Command()) > 0 {
|
||||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||||
log.Debugf("run command: %v", err)
|
log.Debugf("run command: %v", err)
|
||||||
p.handleProxyExitCode(session, err)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -167,16 +154,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
|
|||||||
}
|
}
|
||||||
if err := serverSession.Wait(); err != nil {
|
if err := serverSession.Wait(); err != nil {
|
||||||
log.Debugf("session wait: %v", err)
|
log.Debugf("session wait: %v", err)
|
||||||
p.handleProxyExitCode(session, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
|
||||||
var exitErr *cryptossh.ExitError
|
|
||||||
if errors.As(err, &exitErr) {
|
|
||||||
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
|
|
||||||
log.Debugf("set exit status: %v", exitErr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
|||||||
|
|
||||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||||
|
|
||||||
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
|
execCmd, err := s.createCommand(privilegeResult, session, hasPty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||||
|
|
||||||
@@ -42,59 +42,31 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasPty {
|
if s.executeCommand(logger, session, execCmd) {
|
||||||
if s.executeCommand(logger, session, execCmd, cleanup) {
|
|
||||||
logger.Debugf("%s execution completed", commandType)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
ptyReq, _, _ := session.Pty()
|
|
||||||
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
|
||||||
logger.Debugf("%s execution completed", commandType)
|
logger.Debugf("%s execution completed", commandType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, error) {
|
||||||
localUser := privilegeResult.User
|
localUser := privilegeResult.User
|
||||||
if localUser == nil {
|
|
||||||
return nil, nil, errors.New("no user in privilege result")
|
|
||||||
}
|
|
||||||
|
|
||||||
// If PTY requested but su doesn't support --pty, skip su and use executor
|
|
||||||
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
|
||||||
if hasPty && !s.suSupportsPty {
|
|
||||||
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
|
||||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
|
||||||
}
|
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
|
||||||
return cmd, cleanup, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try su first for system integration (PAM/audit) when privileged
|
// Try su first for system integration (PAM/audit) when privileged
|
||||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
||||||
if err != nil || privilegeResult.UsedFallback {
|
if err != nil || privilegeResult.UsedFallback {
|
||||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
log.Debugf("su command failed, falling back to executor: %v", err)
|
||||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
cmd, err = s.createExecutorCommand(session, localUser, hasPty)
|
||||||
if err != nil {
|
}
|
||||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
|
||||||
}
|
if err != nil {
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
return nil, fmt.Errorf("create command with privileges: %w", err)
|
||||||
return cmd, cleanup, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||||
return cmd, func() {}, nil
|
return cmd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeCommand executes the command and handles I/O and exit codes
|
// executeCommand executes the command and handles I/O and exit codes
|
||||||
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool {
|
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
s.setupProcessGroup(execCmd)
|
s.setupProcessGroup(execCmd)
|
||||||
|
|
||||||
stdinPipe, err := execCmd.StdinPipe()
|
stdinPipe, err := execCmd.StdinPipe()
|
||||||
@@ -107,7 +79,7 @@ func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd
|
|||||||
}
|
}
|
||||||
|
|
||||||
execCmd.Stdout = session
|
execCmd.Stdout = session
|
||||||
execCmd.Stderr = session.Stderr()
|
execCmd.Stderr = session
|
||||||
|
|
||||||
if execCmd.Dir != "" {
|
if execCmd.Dir != "" {
|
||||||
if _, err := os.Stat(execCmd.Dir); err != nil {
|
if _, err := os.Stat(execCmd.Dir); err != nil {
|
||||||
|
|||||||
@@ -3,13 +3,11 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
|
||||||
"github.com/gliderlabs/ssh"
|
"github.com/gliderlabs/ssh"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||||
@@ -20,8 +18,8 @@ func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createExecutorCommand is not supported on JS/WASM
|
// createExecutorCommand is not supported on JS/WASM
|
||||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||||
return nil, nil, errNotSupported
|
return nil, errNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareCommandEnv is not supported on JS/WASM
|
// prepareCommandEnv is not supported on JS/WASM
|
||||||
@@ -34,19 +32,5 @@ func (s *Server) setupProcessGroup(_ *exec.Cmd) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// killProcessGroup is not supported on JS/WASM
|
// killProcessGroup is not supported on JS/WASM
|
||||||
func (s *Server) killProcessGroup(*exec.Cmd) {
|
func (s *Server) killProcessGroup(_ *exec.Cmd) {
|
||||||
}
|
|
||||||
|
|
||||||
// detectSuPtySupport always returns false on JS/WASM
|
|
||||||
func (s *Server) detectSuPtySupport(context.Context) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeCommandWithPty is not supported on JS/WASM
|
|
||||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
|
||||||
logger.Errorf("PTY command execution not supported on JS/WASM")
|
|
||||||
if err := session.Exit(1); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,14 +3,12 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@@ -20,6 +18,47 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// createSuCommand creates a command using su -l -c for privilege switching
|
||||||
|
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||||
|
suPath, err := exec.LookPath("su")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("su command not available: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
command := session.RawCommand()
|
||||||
|
if command == "" {
|
||||||
|
return nil, fmt.Errorf("no command specified for su execution")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: handle pty flag if available
|
||||||
|
args := []string{"-l", localUser.Username, "-c", command}
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||||
|
cmd.Dir = localUser.HomeDir
|
||||||
|
|
||||||
|
return cmd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||||
|
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||||
|
if cmdString == "" {
|
||||||
|
return []string{shell, "-l"}
|
||||||
|
}
|
||||||
|
return []string{shell, "-l", "-c", cmdString}
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||||
|
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||||
|
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||||
|
env = append(env, prepareSSHEnv(session)...)
|
||||||
|
for _, v := range session.Environ() {
|
||||||
|
if acceptEnv(v) {
|
||||||
|
env = append(env, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return env
|
||||||
|
}
|
||||||
|
|
||||||
// ptyManager manages Pty file operations with thread safety
|
// ptyManager manages Pty file operations with thread safety
|
||||||
type ptyManager struct {
|
type ptyManager struct {
|
||||||
file *os.File
|
file *os.File
|
||||||
@@ -49,7 +88,7 @@ func (pm *ptyManager) Setsize(ws *pty.Winsize) error {
|
|||||||
pm.mu.RLock()
|
pm.mu.RLock()
|
||||||
defer pm.mu.RUnlock()
|
defer pm.mu.RUnlock()
|
||||||
if pm.closed {
|
if pm.closed {
|
||||||
return errors.New("pty is closed")
|
return errors.New("Pty is closed")
|
||||||
}
|
}
|
||||||
return pty.Setsize(pm.file, ws)
|
return pty.Setsize(pm.file, ws)
|
||||||
}
|
}
|
||||||
@@ -58,78 +97,6 @@ func (pm *ptyManager) File() *os.File {
|
|||||||
return pm.file
|
return pm.file
|
||||||
}
|
}
|
||||||
|
|
||||||
// detectSuPtySupport checks if su supports the --pty flag
|
|
||||||
func (s *Server) detectSuPtySupport(ctx context.Context) bool {
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
cmd := exec.CommandContext(ctx, "su", "--help")
|
|
||||||
output, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("su --help failed (may not support --help): %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
supported := strings.Contains(string(output), "--pty")
|
|
||||||
log.Debugf("su --pty support detected: %v", supported)
|
|
||||||
return supported
|
|
||||||
}
|
|
||||||
|
|
||||||
// createSuCommand creates a command using su -l -c for privilege switching
|
|
||||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
|
||||||
suPath, err := exec.LookPath("su")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("su command not available: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
command := session.RawCommand()
|
|
||||||
if command == "" {
|
|
||||||
return nil, fmt.Errorf("no command specified for su execution")
|
|
||||||
}
|
|
||||||
|
|
||||||
args := []string{"-l"}
|
|
||||||
if hasPty && s.suSupportsPty {
|
|
||||||
args = append(args, "--pty")
|
|
||||||
}
|
|
||||||
args = append(args, localUser.Username, "-c", command)
|
|
||||||
|
|
||||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
|
||||||
cmd.Dir = localUser.HomeDir
|
|
||||||
|
|
||||||
return cmd, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
|
||||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
|
||||||
if cmdString == "" {
|
|
||||||
return []string{shell, "-l"}
|
|
||||||
}
|
|
||||||
return []string{shell, "-l", "-c", cmdString}
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
|
||||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
|
||||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
|
||||||
env = append(env, prepareSSHEnv(session)...)
|
|
||||||
for _, v := range session.Environ() {
|
|
||||||
if acceptEnv(v) {
|
|
||||||
env = append(env, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return env
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeCommandWithPty executes a command with PTY allocation
|
|
||||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
|
||||||
termType := ptyReq.Term
|
|
||||||
if termType == "" {
|
|
||||||
termType = "xterm-256color"
|
|
||||||
}
|
|
||||||
execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType))
|
|
||||||
|
|
||||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -144,12 +111,14 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("starting interactive shell: %s", execCmd.Path)
|
shell := execCmd.Path
|
||||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
cmd := session.Command()
|
||||||
}
|
if len(cmd) == 0 {
|
||||||
|
logger.Infof("starting interactive shell: %s", shell)
|
||||||
|
} else {
|
||||||
|
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
||||||
|
}
|
||||||
|
|
||||||
// runPtyCommand runs a command with PTY management (common code for interactive and command execution)
|
|
||||||
func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
|
||||||
ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
|
ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Pty start failed: %v", err)
|
logger.Errorf("Pty start failed: %v", err)
|
||||||
@@ -301,29 +270,9 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
|||||||
pgid := cmd.Process.Pid
|
pgid := cmd.Process.Pid
|
||||||
|
|
||||||
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
|
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
|
||||||
logger.Debugf("kill process group SIGTERM: %v", err)
|
logger.Debugf("kill process group SIGTERM failed: %v", err)
|
||||||
return
|
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
|
||||||
}
|
logger.Debugf("kill process group SIGKILL failed: %v", err)
|
||||||
|
|
||||||
const gracePeriod = 500 * time.Millisecond
|
|
||||||
const checkInterval = 50 * time.Millisecond
|
|
||||||
|
|
||||||
ticker := time.NewTicker(checkInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
timeout := time.After(gracePeriod)
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
|
|
||||||
logger.Debugf("kill process group SIGKILL: %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := syscall.Kill(-pgid, 0); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -7,6 +9,7 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
@@ -31,11 +34,6 @@ func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
|
||||||
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
|
||||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||||
@@ -268,11 +266,6 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||||
if privilegeResult.User == nil {
|
|
||||||
logger.Errorf("no user in privilege result")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := session.Command()
|
cmd := session.Command()
|
||||||
shell := getUserShell(privilegeResult.User.Uid)
|
shell := getUserShell(privilegeResult.User.Uid)
|
||||||
|
|
||||||
@@ -282,6 +275,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
|
|||||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always use user switching on Windows - no direct execution
|
||||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -295,8 +289,45 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
||||||
logger.Info("starting interactive shell")
|
localUser := privilegeResult.User
|
||||||
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
|
||||||
|
username, domain := s.parseUsername(localUser.Username)
|
||||||
|
shell := getUserShell(localUser.Uid)
|
||||||
|
|
||||||
|
var command string
|
||||||
|
rawCmd := session.RawCommand()
|
||||||
|
if rawCmd != "" {
|
||||||
|
command = rawCmd
|
||||||
|
}
|
||||||
|
|
||||||
|
req := PtyExecutionRequest{
|
||||||
|
Shell: shell,
|
||||||
|
Command: command,
|
||||||
|
Width: ptyReq.Window.Width,
|
||||||
|
Height: ptyReq.Window.Height,
|
||||||
|
Username: username,
|
||||||
|
Domain: domain,
|
||||||
|
}
|
||||||
|
err := executePtyCommandWithUserToken(session.Context(), session, req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Windows ConPty with user switching failed: %v", err)
|
||||||
|
var errorMsg string
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
errorMsg = "Windows user switching failed - NetBird must run as a Windows service or with elevated privileges for user switching\r\n"
|
||||||
|
} else {
|
||||||
|
errorMsg = "User switching failed - login command not available\r\n"
|
||||||
|
}
|
||||||
|
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
|
||||||
|
logger.Debugf(errWriteSession, writeErr)
|
||||||
|
}
|
||||||
|
if err := session.Exit(1); err != nil {
|
||||||
|
logSessionExitError(logger, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("Windows ConPty command execution with user switching completed")
|
||||||
}
|
}
|
||||||
|
|
||||||
type PtyExecutionRequest struct {
|
type PtyExecutionRequest struct {
|
||||||
@@ -324,7 +355,7 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
server := &Server{}
|
server := &Server{}
|
||||||
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
userEnv, err := server.getUserEnvironment(req.Username, req.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||||
userEnv = os.Environ()
|
userEnv = os.Environ()
|
||||||
@@ -377,54 +408,3 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
|||||||
logger.Debugf("kill process failed: %v", err)
|
logger.Debugf("kill process failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// detectSuPtySupport always returns false on Windows as su is not available
|
|
||||||
func (s *Server) detectSuPtySupport(context.Context) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
|
||||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
|
||||||
command := session.RawCommand()
|
|
||||||
if command == "" {
|
|
||||||
logger.Error("no command specified for PTY execution")
|
|
||||||
if err := session.Exit(1); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
|
||||||
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
|
||||||
localUser := privilegeResult.User
|
|
||||||
if localUser == nil {
|
|
||||||
logger.Errorf("no user in privilege result")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
username, domain := s.parseUsername(localUser.Username)
|
|
||||||
shell := getUserShell(localUser.Uid)
|
|
||||||
|
|
||||||
req := PtyExecutionRequest{
|
|
||||||
Shell: shell,
|
|
||||||
Command: command,
|
|
||||||
Width: ptyReq.Window.Width,
|
|
||||||
Height: ptyReq.Window.Height,
|
|
||||||
Username: username,
|
|
||||||
Domain: domain,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
|
||||||
logger.Errorf("ConPty execution failed: %v", err)
|
|
||||||
if err := session.Exit(1); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("ConPty execution completed")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -61,14 +61,12 @@ const (
|
|||||||
convertDomainError = "convert domain to UTF16: %w"
|
convertDomainError = "convert domain to UTF16: %w"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateWindowsExecutorCommand creates a Windows command with privilege dropping.
|
func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, error) {
|
||||||
// The caller must close the returned token handle after starting the process.
|
|
||||||
func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, windows.Token, error) {
|
|
||||||
if config.Username == "" {
|
if config.Username == "" {
|
||||||
return nil, 0, errors.New("username cannot be empty")
|
return nil, errors.New("username cannot be empty")
|
||||||
}
|
}
|
||||||
if config.Shell == "" {
|
if config.Shell == "" {
|
||||||
return nil, 0, errors.New("shell cannot be empty")
|
return nil, errors.New("shell cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
shell := config.Shell
|
shell := config.Shell
|
||||||
@@ -82,13 +80,13 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
|
|||||||
|
|
||||||
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||||
|
|
||||||
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
cmd, err := pd.CreateWindowsProcessAsUser(
|
||||||
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("create Windows process as user: %w", err)
|
return nil, fmt.Errorf("create Windows process as user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cmd, token, nil
|
return cmd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -183,6 +181,7 @@ func newLsaString(s string) lsaString {
|
|||||||
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
||||||
userCpn := buildUserCpn(username, domain)
|
userCpn := buildUserCpn(username, domain)
|
||||||
|
|
||||||
|
// Use proper domain detection logic instead of simple string check
|
||||||
pd := NewPrivilegeDropper()
|
pd := NewPrivilegeDropper()
|
||||||
isDomainUser := !pd.isLocalUser(domain)
|
isDomainUser := !pd.isLocalUser(domain)
|
||||||
|
|
||||||
@@ -344,7 +343,7 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er
|
|||||||
|
|
||||||
upnOffset := structSize
|
upnOffset := structSize
|
||||||
upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset))
|
upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset))
|
||||||
copy((*[1025]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16)
|
copy((*[512]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16)
|
||||||
|
|
||||||
s4uLogon.ClientUpn = unicodeString{
|
s4uLogon.ClientUpn = unicodeString{
|
||||||
Length: uint16((len(upnUtf16) - 1) * 2),
|
Length: uint16((len(upnUtf16) - 1) * 2),
|
||||||
@@ -516,34 +515,31 @@ func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsernam
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables).
|
// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables)
|
||||||
// The caller must close the returned token handle after starting the process.
|
func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, error) {
|
||||||
func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, windows.Token, error) {
|
fullUsername := buildUserCpn(username, domain)
|
||||||
|
|
||||||
token, err := pd.createToken(username, domain)
|
token, err := pd.createToken(username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("user authentication: %w", err)
|
return nil, fmt.Errorf("user authentication: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("using S4U authentication for user %s", fullUsername)
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(token); err != nil {
|
if err := windows.CloseHandle(token); err != nil {
|
||||||
log.Debugf("close impersonation token: %v", err)
|
log.Debugf("close impersonation token error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
cmd, primaryToken, err := pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir)
|
return pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir)
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return cmd, primaryToken, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// createProcessWithToken creates process with the specified token and executable path.
|
// createProcessWithToken creates process with the specified token and executable path
|
||||||
// The caller must close the returned token handle after starting the process.
|
func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, error) {
|
||||||
func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, windows.Token, error) {
|
|
||||||
cmd := exec.CommandContext(ctx, executablePath, args[1:]...)
|
cmd := exec.CommandContext(ctx, executablePath, args[1:]...)
|
||||||
cmd.Dir = workingDir
|
cmd.Dir = workingDir
|
||||||
|
|
||||||
|
// Duplicate the token to create a primary token that can be used to start a new process
|
||||||
var primaryToken windows.Token
|
var primaryToken windows.Token
|
||||||
err := windows.DuplicateTokenEx(
|
err := windows.DuplicateTokenEx(
|
||||||
sourceToken,
|
sourceToken,
|
||||||
@@ -554,14 +550,14 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
|
|||||||
&primaryToken,
|
&primaryToken,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("duplicate token to primary token: %w", err)
|
return nil, fmt.Errorf("duplicate token to primary token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
Token: syscall.Token(primaryToken),
|
Token: syscall.Token(primaryToken),
|
||||||
}
|
}
|
||||||
|
|
||||||
return cmd, primaryToken, nil
|
return cmd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
||||||
|
|||||||
@@ -311,16 +311,6 @@ func TestJWTFailClose(t *testing.T) {
|
|||||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "blocks_token_exceeding_max_age",
|
|
||||||
tokenClaims: jwt.MapClaims{
|
|
||||||
"iss": issuer,
|
|
||||||
"aud": audience,
|
|
||||||
"sub": "test-user",
|
|
||||||
"exp": time.Now().Add(time.Hour).Unix(),
|
|
||||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
gojwt "github.com/golang-jwt/jwt/v5"
|
gojwt "github.com/golang-jwt/jwt/v5"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
@@ -106,20 +105,12 @@ type sshConnectionState struct {
|
|||||||
remoteAddr string
|
remoteAddr string
|
||||||
}
|
}
|
||||||
|
|
||||||
type authKey string
|
|
||||||
|
|
||||||
func newAuthKey(username string, remoteAddr net.Addr) authKey {
|
|
||||||
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
|
|
||||||
}
|
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
sshServer *ssh.Server
|
sshServer *ssh.Server
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
hostKeyPEM []byte
|
hostKeyPEM []byte
|
||||||
sessions map[SessionKey]ssh.Session
|
sessions map[SessionKey]ssh.Session
|
||||||
sessionCancels map[ConnectionKey]context.CancelFunc
|
sessionCancels map[ConnectionKey]context.CancelFunc
|
||||||
sessionJWTUsers map[SessionKey]string
|
|
||||||
pendingAuthJWT map[authKey]string
|
|
||||||
|
|
||||||
allowLocalPortForwarding bool
|
allowLocalPortForwarding bool
|
||||||
allowRemotePortForwarding bool
|
allowRemotePortForwarding bool
|
||||||
@@ -137,8 +128,6 @@ type Server struct {
|
|||||||
jwtValidator *jwt.Validator
|
jwtValidator *jwt.Validator
|
||||||
jwtExtractor *jwt.ClaimsExtractor
|
jwtExtractor *jwt.ClaimsExtractor
|
||||||
jwtConfig *JWTConfig
|
jwtConfig *JWTConfig
|
||||||
|
|
||||||
suSupportsPty bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type JWTConfig struct {
|
type JWTConfig struct {
|
||||||
@@ -157,14 +146,6 @@ type Config struct {
|
|||||||
HostKeyPEM []byte
|
HostKeyPEM []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionInfo contains information about an active SSH session
|
|
||||||
type SessionInfo struct {
|
|
||||||
Username string
|
|
||||||
RemoteAddress string
|
|
||||||
Command string
|
|
||||||
JWTUsername string
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||||
// If jwtConfig is nil, JWT authentication is disabled
|
// If jwtConfig is nil, JWT authentication is disabled
|
||||||
func New(config *Config) *Server {
|
func New(config *Config) *Server {
|
||||||
@@ -172,8 +153,6 @@ func New(config *Config) *Server {
|
|||||||
mu: sync.RWMutex{},
|
mu: sync.RWMutex{},
|
||||||
hostKeyPEM: config.HostKeyPEM,
|
hostKeyPEM: config.HostKeyPEM,
|
||||||
sessions: make(map[SessionKey]ssh.Session),
|
sessions: make(map[SessionKey]ssh.Session),
|
||||||
sessionJWTUsers: make(map[SessionKey]string),
|
|
||||||
pendingAuthJWT: make(map[authKey]string),
|
|
||||||
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
||||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||||
jwtEnabled: config.JWT != nil,
|
jwtEnabled: config.JWT != nil,
|
||||||
@@ -192,8 +171,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
|||||||
return errors.New("SSH server is already running")
|
return errors.New("SSH server is already running")
|
||||||
}
|
}
|
||||||
|
|
||||||
s.suSupportsPty = s.detectSuPtySupport(ctx)
|
|
||||||
|
|
||||||
ln, addrDesc, err := s.createListener(ctx, addr)
|
ln, addrDesc, err := s.createListener(ctx, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create listener: %w", err)
|
return fmt.Errorf("create listener: %w", err)
|
||||||
@@ -209,7 +186,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
|||||||
log.Infof("SSH server started on %s", addrDesc)
|
log.Infof("SSH server started on %s", addrDesc)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := sshServer.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
|
if err := sshServer.Serve(ln); !isShutdownError(err) {
|
||||||
log.Errorf("SSH server error: %v", err)
|
log.Errorf("SSH server error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -252,58 +229,15 @@ func (s *Server) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.sshServer.Close(); err != nil {
|
if err := s.sshServer.Close(); err != nil && !isShutdownError(err) {
|
||||||
log.Debugf("close SSH server: %v", err)
|
return fmt.Errorf("shutdown SSH server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sshServer = nil
|
s.sshServer = nil
|
||||||
|
|
||||||
maps.Clear(s.sessions)
|
|
||||||
maps.Clear(s.sessionJWTUsers)
|
|
||||||
maps.Clear(s.pendingAuthJWT)
|
|
||||||
maps.Clear(s.sshConnections)
|
|
||||||
|
|
||||||
for _, cancelFunc := range s.sessionCancels {
|
|
||||||
cancelFunc()
|
|
||||||
}
|
|
||||||
maps.Clear(s.sessionCancels)
|
|
||||||
|
|
||||||
for _, listener := range s.remoteForwardListeners {
|
|
||||||
if err := listener.Close(); err != nil {
|
|
||||||
log.Debugf("close remote forward listener: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
maps.Clear(s.remoteForwardListeners)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStatus returns the current status of the SSH server and active sessions
|
|
||||||
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
enabled = s.sshServer != nil
|
|
||||||
|
|
||||||
for sessionKey, session := range s.sessions {
|
|
||||||
cmd := "<interactive shell>"
|
|
||||||
if len(session.Command()) > 0 {
|
|
||||||
cmd = safeLogCommand(session.Command())
|
|
||||||
}
|
|
||||||
|
|
||||||
jwtUsername := s.sessionJWTUsers[sessionKey]
|
|
||||||
|
|
||||||
sessions = append(sessions, SessionInfo{
|
|
||||||
Username: session.User(),
|
|
||||||
RemoteAddress: session.RemoteAddr().String(),
|
|
||||||
Command: cmd,
|
|
||||||
JWTUsername: jwtUsername,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return enabled, sessions
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetstackNet sets the netstack network for userspace networking
|
// SetNetstackNet sets the netstack network for userspace networking
|
||||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
@@ -388,15 +322,10 @@ func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||||
if jwtConfig == nil {
|
if jwtConfig == nil || jwtConfig.MaxTokenAge <= 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
maxTokenAge := jwtConfig.MaxTokenAge
|
|
||||||
if maxTokenAge <= 0 {
|
|
||||||
maxTokenAge = DefaultJWTMaxTokenAge
|
|
||||||
}
|
|
||||||
|
|
||||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||||
if !ok {
|
if !ok {
|
||||||
userID := extractUserID(token)
|
userID := extractUserID(token)
|
||||||
@@ -411,7 +340,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
|||||||
|
|
||||||
issuedAt := time.Unix(int64(iat), 0)
|
issuedAt := time.Unix(int64(iat), 0)
|
||||||
tokenAge := time.Since(issuedAt)
|
tokenAge := time.Since(issuedAt)
|
||||||
maxAge := time.Duration(maxTokenAge) * time.Second
|
maxAge := time.Duration(jwtConfig.MaxTokenAge) * time.Second
|
||||||
if tokenAge > maxAge {
|
if tokenAge > maxAge {
|
||||||
userID := getUserIDFromClaims(claims)
|
userID := getUserIDFromClaims(claims)
|
||||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
||||||
@@ -508,11 +437,6 @@ func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
|
|
||||||
s.mu.Lock()
|
|
||||||
s.pendingAuthJWT[key] = userAuth.UserId
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -608,6 +532,19 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
|||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isShutdownError(err error) bool {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var opErr *net.OpError
|
||||||
|
if errors.As(err, &opErr) && opErr.Op == "accept" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||||
if err := enableUserSwitching(); err != nil {
|
if err := enableUserSwitching(); err != nil {
|
||||||
log.Warnf("failed to enable user switching: %v", err)
|
log.Warnf("failed to enable user switching: %v", err)
|
||||||
|
|||||||
@@ -65,12 +65,9 @@ func TestSSHServerIntegration(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
||||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
started <- actualAddr
|
started <- actualAddr
|
||||||
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||||
|
errChan <- server.Start(context.Background(), addrPort)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -82,6 +79,8 @@ func TestSSHServerIntegration(t *testing.T) {
|
|||||||
t.Fatal("Server start timeout")
|
t.Fatal("Server start timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Server is ready when we get the started signal
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := server.Stop()
|
err := server.Stop()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -167,12 +166,9 @@ func TestSSHServerMultipleConnections(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
||||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
started <- actualAddr
|
started <- actualAddr
|
||||||
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||||
|
errChan <- server.Start(context.Background(), addrPort)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -184,6 +180,8 @@ func TestSSHServerMultipleConnections(t *testing.T) {
|
|||||||
t.Fatal("Server start timeout")
|
t.Fatal("Server start timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Server is ready when we get the started signal
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := server.Stop()
|
err := server.Stop()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -279,12 +277,9 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
||||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
started <- actualAddr
|
started <- actualAddr
|
||||||
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||||
|
errChan <- server.Start(context.Background(), addrPort)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -296,6 +291,8 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
|||||||
t.Fatal("Server start timeout")
|
t.Fatal("Server start timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Server is ready when we get the started signal
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := server.Stop()
|
err := server.Stop()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -364,12 +361,9 @@ func TestSSHServerStartStopCycle(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
||||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
started <- actualAddr
|
started <- actualAddr
|
||||||
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||||
|
errChan <- server.Start(context.Background(), addrPort)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -11,29 +11,13 @@ import (
|
|||||||
|
|
||||||
"github.com/gliderlabs/ssh"
|
"github.com/gliderlabs/ssh"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// sessionHandler handles SSH sessions
|
// sessionHandler handles SSH sessions
|
||||||
func (s *Server) sessionHandler(session ssh.Session) {
|
func (s *Server) sessionHandler(session ssh.Session) {
|
||||||
sessionKey := s.registerSession(session)
|
sessionKey := s.registerSession(session)
|
||||||
|
|
||||||
key := newAuthKey(session.User(), session.RemoteAddr())
|
|
||||||
s.mu.Lock()
|
|
||||||
jwtUsername := s.pendingAuthJWT[key]
|
|
||||||
if jwtUsername != "" {
|
|
||||||
s.sessionJWTUsers[sessionKey] = jwtUsername
|
|
||||||
delete(s.pendingAuthJWT, key)
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
logger := log.WithField("session", sessionKey)
|
logger := log.WithField("session", sessionKey)
|
||||||
if jwtUsername != "" {
|
logger.Infof("SSH session started")
|
||||||
logger = logger.WithField("jwt_user", jwtUsername)
|
|
||||||
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
|
|
||||||
} else {
|
|
||||||
logger.Infof("SSH session started")
|
|
||||||
}
|
|
||||||
sessionStart := time.Now()
|
sessionStart := time.Now()
|
||||||
|
|
||||||
defer s.unregisterSession(sessionKey, session)
|
defer s.unregisterSession(sessionKey, session)
|
||||||
@@ -102,10 +86,9 @@ func (s *Server) registerSession(session ssh.Session) SessionKey {
|
|||||||
return sessionKey
|
return sessionKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
|
func (s *Server) unregisterSession(sessionKey SessionKey, _ ssh.Session) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
delete(s.sessions, sessionKey)
|
delete(s.sessions, sessionKey)
|
||||||
delete(s.sessionJWTUsers, sessionKey)
|
|
||||||
|
|
||||||
// Cancel all port forwarding connections for this session
|
// Cancel all port forwarding connections for this session
|
||||||
var connectionsToCancel []ConnectionKey
|
var connectionsToCancel []ConnectionKey
|
||||||
@@ -123,12 +106,6 @@ func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
|
|
||||||
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
|
|
||||||
delete(s.sshConnections, sshConn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -62,12 +62,9 @@ func TestSSHServer_SFTPSubsystem(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
||||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
started <- actualAddr
|
started <- actualAddr
|
||||||
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||||
|
errChan <- server.Start(context.Background(), addrPort)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -171,12 +168,9 @@ func TestSSHServer_SFTPDisabled(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
||||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
started <- actualAddr
|
started <- actualAddr
|
||||||
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||||
|
errChan <- server.Start(context.Background(), addrPort)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -14,14 +14,13 @@ import (
|
|||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
// createSftpCommand creates a Windows SFTP command with user switching.
|
// createSftpCommand creates a Windows SFTP command with user switching
|
||||||
// The caller must close the returned token handle after starting the process.
|
func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, error) {
|
||||||
func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, windows.Token, error) {
|
|
||||||
username, domain := s.parseUsername(targetUser.Username)
|
username, domain := s.parseUsername(targetUser.Username)
|
||||||
|
|
||||||
netbirdPath, err := os.Executable()
|
netbirdPath, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("get netbird executable path: %w", err)
|
return nil, fmt.Errorf("get netbird executable path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []string{
|
args := []string{
|
||||||
@@ -34,32 +33,27 @@ func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*ex
|
|||||||
pd := NewPrivilegeDropper()
|
pd := NewPrivilegeDropper()
|
||||||
token, err := pd.createToken(username, domain)
|
token, err := pd.createToken(username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("create token: %w", err)
|
return nil, fmt.Errorf("create token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(token); err != nil {
|
if err := windows.CloseHandle(token); err != nil {
|
||||||
log.Warnf("failed to close impersonation token: %v", err)
|
log.Warnf("failed to close Windows token handle: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
cmd, primaryToken, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir)
|
cmd, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("create SFTP command: %w", err)
|
return nil, fmt.Errorf("create SFTP command: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username)
|
log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username)
|
||||||
return cmd, primaryToken, nil
|
return cmd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeSftpCommand executes a Windows SFTP command with proper I/O handling
|
// executeSftpCommand executes a Windows SFTP command with proper I/O handling
|
||||||
func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token windows.Token) error {
|
func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd) error {
|
||||||
defer func() {
|
|
||||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
|
||||||
log.Debugf("close primary token: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
sftpCmd.Stdin = sess
|
sftpCmd.Stdin = sess
|
||||||
sftpCmd.Stdout = sess
|
sftpCmd.Stdout = sess
|
||||||
sftpCmd.Stderr = sess.Stderr()
|
sftpCmd.Stderr = sess.Stderr()
|
||||||
@@ -83,9 +77,9 @@ func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token w
|
|||||||
|
|
||||||
// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping
|
// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping
|
||||||
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
|
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
|
||||||
sftpCmd, token, err := s.createSftpCommand(targetUser, sess)
|
sftpCmd, err := s.createSftpCommand(targetUser, sess)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create sftp: %w", err)
|
return fmt.Errorf("create sftp: %w", err)
|
||||||
}
|
}
|
||||||
return s.executeSftpCommand(sess, sftpCmd, token)
|
return s.executeSftpCommand(sess, sftpCmd)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -99,17 +99,12 @@ func getShellFromPasswd(userID string) string {
|
|||||||
|
|
||||||
// prepareUserEnv prepares environment variables for user execution
|
// prepareUserEnv prepares environment variables for user execution
|
||||||
func prepareUserEnv(user *user.User, shell string) []string {
|
func prepareUserEnv(user *user.User, shell string) []string {
|
||||||
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
|
|
||||||
}
|
|
||||||
|
|
||||||
return []string{
|
return []string{
|
||||||
fmt.Sprint("SHELL=" + shell),
|
fmt.Sprint("SHELL=" + shell),
|
||||||
fmt.Sprint("USER=" + user.Username),
|
fmt.Sprint("USER=" + user.Username),
|
||||||
fmt.Sprint("LOGNAME=" + user.Username),
|
fmt.Sprint("LOGNAME=" + user.Username),
|
||||||
fmt.Sprint("HOME=" + user.HomeDir),
|
fmt.Sprint("HOME=" + user.HomeDir),
|
||||||
"PATH=" + pathValue,
|
"PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := validateUsername(requestedUsername); err != nil {
|
if err := validateUsername(requestedUsername); err != nil {
|
||||||
return nil, fmt.Errorf("invalid username %q: %w", requestedUsername, err)
|
return nil, fmt.Errorf("invalid username: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err := lookupUser(requestedUsername)
|
u, err := lookupUser(requestedUsername)
|
||||||
|
|||||||
@@ -152,18 +152,17 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
|
|||||||
return groups, nil
|
return groups, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
|
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||||
// Returns the command and a cleanup function (no-op on Unix).
|
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
|
||||||
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||||
|
|
||||||
if err := validateUsername(localUser.Username); err != nil {
|
if err := validateUsername(localUser.Username); err != nil {
|
||||||
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
return nil, fmt.Errorf("invalid username: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
uid, gid, groups, err := s.parseUserCredentials(localUser)
|
uid, gid, groups, err := s.parseUserCredentials(localUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
|
return nil, fmt.Errorf("parse user credentials: %w", err)
|
||||||
}
|
}
|
||||||
privilegeDropper := NewPrivilegeDropper()
|
privilegeDropper := NewPrivilegeDropper()
|
||||||
config := ExecutorConfig{
|
config := ExecutorConfig{
|
||||||
@@ -176,8 +175,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User
|
|||||||
PTY: hasPty,
|
PTY: hasPty,
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd, err := privilegeDropper.CreateExecutorCommand(session.Context(), config)
|
return privilegeDropper.CreateExecutorCommand(session.Context(), config)
|
||||||
return cmd, func() {}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// enableUserSwitching is a no-op on Unix systems
|
// enableUserSwitching is a no-op on Unix systems
|
||||||
@@ -188,9 +186,6 @@ func enableUserSwitching() error {
|
|||||||
// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results
|
// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results
|
||||||
func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
|
func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
|
||||||
localUser := privilegeResult.User
|
localUser := privilegeResult.User
|
||||||
if localUser == nil {
|
|
||||||
return nil, errors.New("no user in privilege result")
|
|
||||||
}
|
|
||||||
|
|
||||||
if privilegeResult.UsedFallback {
|
if privilegeResult.UsedFallback {
|
||||||
return s.createDirectPtyCommand(session, localUser, ptyReq), nil
|
return s.createDirectPtyCommand(session, localUser, ptyReq), nil
|
||||||
|
|||||||
@@ -86,22 +86,20 @@ func validateUsernameFormat(username string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createExecutorCommand creates a command using Windows executor for privilege dropping.
|
// createExecutorCommand creates a command using Windows executor for privilege dropping
|
||||||
// Returns the command and a cleanup function that must be called after starting the process.
|
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
|
||||||
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||||
|
|
||||||
username, _ := s.parseUsername(localUser.Username)
|
username, _ := s.parseUsername(localUser.Username)
|
||||||
if err := validateUsername(username); err != nil {
|
if err := validateUsername(username); err != nil {
|
||||||
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
|
return nil, fmt.Errorf("invalid username: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.createUserSwitchCommand(localUser, session, hasPty)
|
return s.createUserSwitchCommand(localUser, session, hasPty)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createUserSwitchCommand creates a command with Windows user switching.
|
// createUserSwitchCommand creates a command with Windows user switching
|
||||||
// Returns the command and a cleanup function that must be called after starting the process.
|
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, error) {
|
||||||
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
|
|
||||||
username, domain := s.parseUsername(localUser.Username)
|
username, domain := s.parseUsername(localUser.Username)
|
||||||
|
|
||||||
shell := getUserShell(localUser.Uid)
|
shell := getUserShell(localUser.Uid)
|
||||||
@@ -122,20 +120,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
|
|||||||
}
|
}
|
||||||
|
|
||||||
dropper := NewPrivilegeDropper()
|
dropper := NewPrivilegeDropper()
|
||||||
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
return dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cleanup := func() {
|
|
||||||
if token != 0 {
|
|
||||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
|
||||||
log.Debugf("close primary token: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return cmd, cleanup, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseUsername extracts username and domain from a Windows username
|
// parseUsername extracts username and domain from a Windows username
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error {
|
|||||||
log.Debugf("close output write handle: %v", err)
|
log.Debugf("close output write handle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, session, pi.Process)
|
return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, pi.Process)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createConPtyPipes creates input/output pipes for ConPty.
|
// createConPtyPipes creates input/output pipes for ConPty.
|
||||||
@@ -323,13 +323,8 @@ func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) {
|
|||||||
return primaryToken, nil
|
return primaryToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionExiter provides the Exit method for reporting process exit status.
|
|
||||||
type SessionExiter interface {
|
|
||||||
Exit(code int) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
|
// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
|
||||||
func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, session SessionExiter, process windows.Handle) error {
|
func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, process windows.Handle) error {
|
||||||
if err := ctx.Err(); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -342,15 +337,6 @@ func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Ha
|
|||||||
return processErr
|
return processErr
|
||||||
}
|
}
|
||||||
|
|
||||||
var exitCode uint32
|
|
||||||
if err := windows.GetExitCodeProcess(process, &exitCode); err != nil {
|
|
||||||
log.Debugf("get exit code: %v", err)
|
|
||||||
} else {
|
|
||||||
if err := session.Exit(int(exitCode)); err != nil {
|
|
||||||
log.Debugf("report exit code: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up in the original order after process completes
|
// Clean up in the original order after process completes
|
||||||
if err := reader.Close(); err != nil {
|
if err := reader.Close(); err != nil {
|
||||||
log.Debugf("close reader: %v", err)
|
log.Debugf("close reader: %v", err)
|
||||||
|
|||||||
@@ -228,7 +228,6 @@ func TestWindowsHandleReader(t *testing.T) {
|
|||||||
if err := windows.CloseHandle(writeHandle); err != nil {
|
if err := windows.CloseHandle(writeHandle); err != nil {
|
||||||
t.Fatalf("Should close write handle: %v", err)
|
t.Fatalf("Should close write handle: %v", err)
|
||||||
}
|
}
|
||||||
writeHandle = windows.InvalidHandle
|
|
||||||
|
|
||||||
// Test reading
|
// Test reading
|
||||||
reader := &windowsHandleReader{handle: readHandle}
|
reader := &windowsHandleReader{handle: readHandle}
|
||||||
|
|||||||
@@ -11,12 +11,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
|
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
|
||||||
@@ -85,18 +81,6 @@ type NsServerGroupStateOutput struct {
|
|||||||
Error string `json:"error" yaml:"error"`
|
Error string `json:"error" yaml:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SSHSessionOutput struct {
|
|
||||||
Username string `json:"username" yaml:"username"`
|
|
||||||
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
|
|
||||||
Command string `json:"command" yaml:"command"`
|
|
||||||
JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SSHServerStateOutput struct {
|
|
||||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
|
||||||
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OutputOverview struct {
|
type OutputOverview struct {
|
||||||
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
||||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||||
@@ -116,10 +100,11 @@ type OutputOverview struct {
|
|||||||
Events []SystemEventOutput `json:"events" yaml:"events"`
|
Events []SystemEventOutput `json:"events" yaml:"events"`
|
||||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, daemonVersion string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
|
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
|
||||||
|
pbFullStatus := resp.GetFullStatus()
|
||||||
|
|
||||||
managementState := pbFullStatus.GetManagementState()
|
managementState := pbFullStatus.GetManagementState()
|
||||||
managementOverview := ManagementStateOutput{
|
managementOverview := ManagementStateOutput{
|
||||||
URL: managementState.GetURL(),
|
URL: managementState.GetURL(),
|
||||||
@@ -135,13 +120,12 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, da
|
|||||||
}
|
}
|
||||||
|
|
||||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
|
||||||
peersOverview := mapPeers(pbFullStatus.GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
|
|
||||||
|
|
||||||
overview := OutputOverview{
|
overview := OutputOverview{
|
||||||
Peers: peersOverview,
|
Peers: peersOverview,
|
||||||
CliVersion: version.NetbirdVersion(),
|
CliVersion: version.NetbirdVersion(),
|
||||||
DaemonVersion: daemonVersion,
|
DaemonVersion: resp.GetDaemonVersion(),
|
||||||
ManagementState: managementOverview,
|
ManagementState: managementOverview,
|
||||||
SignalState: signalOverview,
|
SignalState: signalOverview,
|
||||||
Relays: relayOverview,
|
Relays: relayOverview,
|
||||||
@@ -157,7 +141,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, da
|
|||||||
Events: mapEvents(pbFullStatus.GetEvents()),
|
Events: mapEvents(pbFullStatus.GetEvents()),
|
||||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||||
ProfileName: profName,
|
ProfileName: profName,
|
||||||
SSHServerState: sshServerOverview,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if anon {
|
if anon {
|
||||||
@@ -207,30 +190,6 @@ func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
|
|||||||
return mappedNSGroups
|
return mappedNSGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
|
|
||||||
if sshServerState == nil {
|
|
||||||
return SSHServerStateOutput{
|
|
||||||
Enabled: false,
|
|
||||||
Sessions: []SSHSessionOutput{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sessions := make([]SSHSessionOutput, 0, len(sshServerState.GetSessions()))
|
|
||||||
for _, session := range sshServerState.GetSessions() {
|
|
||||||
sessions = append(sessions, SSHSessionOutput{
|
|
||||||
Username: session.GetUsername(),
|
|
||||||
RemoteAddress: session.GetRemoteAddress(),
|
|
||||||
Command: session.GetCommand(),
|
|
||||||
JWTUsername: session.GetJwtUsername(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return SSHServerStateOutput{
|
|
||||||
Enabled: sshServerState.GetEnabled(),
|
|
||||||
Sessions: sessions,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapPeers(
|
func mapPeers(
|
||||||
peers []*proto.PeerState,
|
peers []*proto.PeerState,
|
||||||
statusFilter string,
|
statusFilter string,
|
||||||
@@ -341,7 +300,7 @@ func ParseToYAML(overview OutputOverview) (string, error) {
|
|||||||
return string(yamlBytes), nil
|
return string(yamlBytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
||||||
var managementConnString string
|
var managementConnString string
|
||||||
if overview.ManagementState.Connected {
|
if overview.ManagementState.Connected {
|
||||||
managementConnString = "Connected"
|
managementConnString = "Connected"
|
||||||
@@ -446,41 +405,6 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
lazyConnectionEnabledStatus = "true"
|
lazyConnectionEnabledStatus = "true"
|
||||||
}
|
}
|
||||||
|
|
||||||
sshServerStatus := "Disabled"
|
|
||||||
if overview.SSHServerState.Enabled {
|
|
||||||
sessionCount := len(overview.SSHServerState.Sessions)
|
|
||||||
if sessionCount > 0 {
|
|
||||||
sessionWord := "session"
|
|
||||||
if sessionCount > 1 {
|
|
||||||
sessionWord = "sessions"
|
|
||||||
}
|
|
||||||
sshServerStatus = fmt.Sprintf("Enabled (%d active %s)", sessionCount, sessionWord)
|
|
||||||
} else {
|
|
||||||
sshServerStatus = "Enabled"
|
|
||||||
}
|
|
||||||
|
|
||||||
if showSSHSessions && sessionCount > 0 {
|
|
||||||
for _, session := range overview.SSHServerState.Sessions {
|
|
||||||
var sessionDisplay string
|
|
||||||
if session.JWTUsername != "" {
|
|
||||||
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
|
|
||||||
session.JWTUsername,
|
|
||||||
session.RemoteAddress,
|
|
||||||
session.Username,
|
|
||||||
session.Command,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
sessionDisplay = fmt.Sprintf("[%s@%s] %s",
|
|
||||||
session.Username,
|
|
||||||
session.RemoteAddress,
|
|
||||||
session.Command,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
sshServerStatus += "\n " + sessionDisplay
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||||
|
|
||||||
goos := runtime.GOOS
|
goos := runtime.GOOS
|
||||||
@@ -504,7 +428,6 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
"Interface type: %s\n"+
|
"Interface type: %s\n"+
|
||||||
"Quantum resistance: %s\n"+
|
"Quantum resistance: %s\n"+
|
||||||
"Lazy connection: %s\n"+
|
"Lazy connection: %s\n"+
|
||||||
"SSH Server: %s\n"+
|
|
||||||
"Networks: %s\n"+
|
"Networks: %s\n"+
|
||||||
"Forwarding rules: %d\n"+
|
"Forwarding rules: %d\n"+
|
||||||
"Peers count: %s\n",
|
"Peers count: %s\n",
|
||||||
@@ -521,7 +444,6 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
interfaceTypeString,
|
interfaceTypeString,
|
||||||
rosenpassEnabledStatus,
|
rosenpassEnabledStatus,
|
||||||
lazyConnectionEnabledStatus,
|
lazyConnectionEnabledStatus,
|
||||||
sshServerStatus,
|
|
||||||
networks,
|
networks,
|
||||||
overview.NumberOfForwardingRules,
|
overview.NumberOfForwardingRules,
|
||||||
peersCountString,
|
peersCountString,
|
||||||
@@ -532,7 +454,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
func ParseToFullDetailSummary(overview OutputOverview) string {
|
func ParseToFullDetailSummary(overview OutputOverview) string {
|
||||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||||
parsedEventsString := parseEvents(overview.Events)
|
parsedEventsString := parseEvents(overview.Events)
|
||||||
summary := ParseGeneralSummary(overview, true, true, true, true)
|
summary := ParseGeneralSummary(overview, true, true, true)
|
||||||
|
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"Peers detail:"+
|
"Peers detail:"+
|
||||||
@@ -546,94 +468,6 @@ func ParseToFullDetailSummary(overview OutputOverview) string {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
|
||||||
pbFullStatus := proto.FullStatus{
|
|
||||||
ManagementState: &proto.ManagementState{},
|
|
||||||
SignalState: &proto.SignalState{},
|
|
||||||
LocalPeerState: &proto.LocalPeerState{},
|
|
||||||
Peers: []*proto.PeerState{},
|
|
||||||
}
|
|
||||||
|
|
||||||
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
|
|
||||||
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
|
|
||||||
if err := fullStatus.ManagementState.Error; err != nil {
|
|
||||||
pbFullStatus.ManagementState.Error = err.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
|
|
||||||
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
|
|
||||||
if err := fullStatus.SignalState.Error; err != nil {
|
|
||||||
pbFullStatus.SignalState.Error = err.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
|
|
||||||
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
|
|
||||||
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
|
|
||||||
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
|
|
||||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
|
|
||||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
|
||||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
|
||||||
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
|
||||||
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
|
|
||||||
|
|
||||||
for _, peerState := range fullStatus.Peers {
|
|
||||||
pbPeerState := &proto.PeerState{
|
|
||||||
IP: peerState.IP,
|
|
||||||
PubKey: peerState.PubKey,
|
|
||||||
ConnStatus: peerState.ConnStatus.String(),
|
|
||||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
|
||||||
Relayed: peerState.Relayed,
|
|
||||||
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
|
||||||
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
|
||||||
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
|
||||||
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
|
||||||
RelayAddress: peerState.RelayServerAddress,
|
|
||||||
Fqdn: peerState.FQDN,
|
|
||||||
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
|
||||||
BytesRx: peerState.BytesRx,
|
|
||||||
BytesTx: peerState.BytesTx,
|
|
||||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
|
||||||
Networks: maps.Keys(peerState.GetRoutes()),
|
|
||||||
Latency: durationpb.New(peerState.Latency),
|
|
||||||
SshHostKey: peerState.SSHHostKey,
|
|
||||||
}
|
|
||||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, relayState := range fullStatus.Relays {
|
|
||||||
pbRelayState := &proto.RelayState{
|
|
||||||
URI: relayState.URI,
|
|
||||||
Available: relayState.Err == nil,
|
|
||||||
}
|
|
||||||
if err := relayState.Err; err != nil {
|
|
||||||
pbRelayState.Error = err.Error()
|
|
||||||
}
|
|
||||||
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dnsState := range fullStatus.NSGroupStates {
|
|
||||||
var err string
|
|
||||||
if dnsState.Error != nil {
|
|
||||||
err = dnsState.Error.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
var servers []string
|
|
||||||
for _, server := range dnsState.Servers {
|
|
||||||
servers = append(servers, server.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
pbDnsState := &proto.NSGroupState{
|
|
||||||
Servers: servers,
|
|
||||||
Domains: dnsState.Domains,
|
|
||||||
Enabled: dnsState.Enabled,
|
|
||||||
Error: err,
|
|
||||||
}
|
|
||||||
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &pbFullStatus
|
|
||||||
}
|
|
||||||
|
|
||||||
func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
||||||
var (
|
var (
|
||||||
peersString = ""
|
peersString = ""
|
||||||
@@ -912,13 +746,4 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
|||||||
event.Metadata[k] = a.AnonymizeString(v)
|
event.Metadata[k] = a.AnonymizeString(v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, session := range overview.SSHServerState.Sessions {
|
|
||||||
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
|
|
||||||
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
|
||||||
} else {
|
|
||||||
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
|
|
||||||
}
|
|
||||||
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -231,14 +231,10 @@ var overview = OutputOverview{
|
|||||||
Networks: []string{
|
Networks: []string{
|
||||||
"10.10.0.0/24",
|
"10.10.0.0/24",
|
||||||
},
|
},
|
||||||
SSHServerState: SSHServerStateOutput{
|
|
||||||
Enabled: false,
|
|
||||||
Sessions: []SSHSessionOutput{},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||||
convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), false, resp.GetDaemonVersion(), "", nil, nil, nil, "", "")
|
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "", "")
|
||||||
|
|
||||||
assert.Equal(t, overview, convertedResult)
|
assert.Equal(t, overview, convertedResult)
|
||||||
}
|
}
|
||||||
@@ -389,11 +385,7 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
],
|
],
|
||||||
"events": [],
|
"events": [],
|
||||||
"lazyConnectionEnabled": false,
|
"lazyConnectionEnabled": false,
|
||||||
"profileName":"",
|
"profileName":""
|
||||||
"sshServer":{
|
|
||||||
"enabled":false,
|
|
||||||
"sessions":[]
|
|
||||||
}
|
|
||||||
}`
|
}`
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
|
|
||||||
@@ -496,9 +488,6 @@ dnsServers:
|
|||||||
events: []
|
events: []
|
||||||
lazyConnectionEnabled: false
|
lazyConnectionEnabled: false
|
||||||
profileName: ""
|
profileName: ""
|
||||||
sshServer:
|
|
||||||
enabled: false
|
|
||||||
sessions: []
|
|
||||||
`
|
`
|
||||||
|
|
||||||
assert.Equal(t, expectedYAML, yaml)
|
assert.Equal(t, expectedYAML, yaml)
|
||||||
@@ -565,7 +554,6 @@ NetBird IP: 192.168.178.100/16
|
|||||||
Interface type: Kernel
|
Interface type: Kernel
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Lazy connection: false
|
Lazy connection: false
|
||||||
SSH Server: Disabled
|
|
||||||
Networks: 10.10.0.0/24
|
Networks: 10.10.0.0/24
|
||||||
Forwarding rules: 0
|
Forwarding rules: 0
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
@@ -575,7 +563,7 @@ Peers count: 2/2 Connected
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestParsingToShortVersion(t *testing.T) {
|
func TestParsingToShortVersion(t *testing.T) {
|
||||||
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
|
shortVersion := ParseGeneralSummary(overview, false, false, false)
|
||||||
|
|
||||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||||
Daemon version: 0.14.1
|
Daemon version: 0.14.1
|
||||||
@@ -590,7 +578,6 @@ NetBird IP: 192.168.178.100/16
|
|||||||
Interface type: Kernel
|
Interface type: Kernel
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Lazy connection: false
|
Lazy connection: false
|
||||||
SSH Server: Disabled
|
|
||||||
Networks: 10.10.0.0/24
|
Networks: 10.10.0.0/24
|
||||||
Forwarding rules: 0
|
Forwarding rules: 0
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 4.9 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 7.4 KiB |
@@ -55,7 +55,6 @@ const (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
censoredPreSharedKey = "**********"
|
censoredPreSharedKey = "**********"
|
||||||
maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -86,22 +85,21 @@ func main() {
|
|||||||
|
|
||||||
// Create the service client (this also builds the settings or networks UI if requested).
|
// Create the service client (this also builds the settings or networks UI if requested).
|
||||||
client := newServiceClient(&newServiceClientArgs{
|
client := newServiceClient(&newServiceClientArgs{
|
||||||
addr: flags.daemonAddr,
|
addr: flags.daemonAddr,
|
||||||
logFile: logFile,
|
logFile: logFile,
|
||||||
app: a,
|
app: a,
|
||||||
showSettings: flags.showSettings,
|
showSettings: flags.showSettings,
|
||||||
showNetworks: flags.showNetworks,
|
showNetworks: flags.showNetworks,
|
||||||
showLoginURL: flags.showLoginURL,
|
showLoginURL: flags.showLoginURL,
|
||||||
showDebug: flags.showDebug,
|
showDebug: flags.showDebug,
|
||||||
showProfiles: flags.showProfiles,
|
showProfiles: flags.showProfiles,
|
||||||
showQuickActions: flags.showQuickActions,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// Watch for theme/settings changes to update the icon.
|
// Watch for theme/settings changes to update the icon.
|
||||||
go watchSettingsChanges(a, client)
|
go watchSettingsChanges(a, client)
|
||||||
|
|
||||||
// Run in window mode if any UI flag was set.
|
// Run in window mode if any UI flag was set.
|
||||||
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions {
|
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles {
|
||||||
a.Run()
|
a.Run()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -113,29 +111,23 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if running {
|
if running {
|
||||||
log.Infof("another process is running with pid %d, sending signal to show window", pid)
|
log.Warnf("another process is running with pid %d, exiting", pid)
|
||||||
if err := sendShowWindowSignal(pid); err != nil {
|
|
||||||
log.Errorf("send signal to running instance: %v", err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client.setupSignalHandler(client.ctx)
|
|
||||||
|
|
||||||
client.setDefaultFonts()
|
client.setDefaultFonts()
|
||||||
systray.Run(client.onTrayReady, client.onTrayExit)
|
systray.Run(client.onTrayReady, client.onTrayExit)
|
||||||
}
|
}
|
||||||
|
|
||||||
type cliFlags struct {
|
type cliFlags struct {
|
||||||
daemonAddr string
|
daemonAddr string
|
||||||
showSettings bool
|
showSettings bool
|
||||||
showNetworks bool
|
showNetworks bool
|
||||||
showProfiles bool
|
showProfiles bool
|
||||||
showDebug bool
|
showDebug bool
|
||||||
showLoginURL bool
|
showLoginURL bool
|
||||||
showQuickActions bool
|
errorMsg string
|
||||||
errorMsg string
|
saveLogsInFile bool
|
||||||
saveLogsInFile bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseFlags reads and returns all needed command-line flags.
|
// parseFlags reads and returns all needed command-line flags.
|
||||||
@@ -151,7 +143,6 @@ func parseFlags() *cliFlags {
|
|||||||
flag.BoolVar(&flags.showNetworks, "networks", false, "run networks window")
|
flag.BoolVar(&flags.showNetworks, "networks", false, "run networks window")
|
||||||
flag.BoolVar(&flags.showProfiles, "profiles", false, "run profiles window")
|
flag.BoolVar(&flags.showProfiles, "profiles", false, "run profiles window")
|
||||||
flag.BoolVar(&flags.showDebug, "debug", false, "run debug window")
|
flag.BoolVar(&flags.showDebug, "debug", false, "run debug window")
|
||||||
flag.BoolVar(&flags.showQuickActions, "quick-actions", false, "run quick actions window")
|
|
||||||
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
|
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
|
||||||
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
|
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
|
||||||
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
|
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
|
||||||
@@ -167,9 +158,11 @@ func initLogFile() (string, error) {
|
|||||||
|
|
||||||
// watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon.
|
// watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon.
|
||||||
func watchSettingsChanges(a fyne.App, client *serviceClient) {
|
func watchSettingsChanges(a fyne.App, client *serviceClient) {
|
||||||
a.Settings().AddListener(func(settings fyne.Settings) {
|
settingsChangeChan := make(chan fyne.Settings)
|
||||||
|
a.Settings().AddChangeListener(settingsChangeChan)
|
||||||
|
for range settingsChangeChan {
|
||||||
client.updateIcon()
|
client.updateIcon()
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// showErrorMessage displays an error message in a simple window.
|
// showErrorMessage displays an error message in a simple window.
|
||||||
@@ -277,7 +270,6 @@ type serviceClient struct {
|
|||||||
sEnableSSHLocalPortForward *widget.Check
|
sEnableSSHLocalPortForward *widget.Check
|
||||||
sEnableSSHRemotePortForward *widget.Check
|
sEnableSSHRemotePortForward *widget.Check
|
||||||
sDisableSSHAuth *widget.Check
|
sDisableSSHAuth *widget.Check
|
||||||
iSSHJWTCacheTTL *widget.Entry
|
|
||||||
|
|
||||||
// observable settings over corresponding iMngURL and iPreSharedKey values.
|
// observable settings over corresponding iMngURL and iPreSharedKey values.
|
||||||
managementURL string
|
managementURL string
|
||||||
@@ -297,7 +289,6 @@ type serviceClient struct {
|
|||||||
enableSSHLocalPortForward bool
|
enableSSHLocalPortForward bool
|
||||||
enableSSHRemotePortForward bool
|
enableSSHRemotePortForward bool
|
||||||
disableSSHAuth bool
|
disableSSHAuth bool
|
||||||
sshJWTCacheTTL int
|
|
||||||
|
|
||||||
connected bool
|
connected bool
|
||||||
update *version.Update
|
update *version.Update
|
||||||
@@ -307,7 +298,6 @@ type serviceClient struct {
|
|||||||
showNetworks bool
|
showNetworks bool
|
||||||
wNetworks fyne.Window
|
wNetworks fyne.Window
|
||||||
wProfiles fyne.Window
|
wProfiles fyne.Window
|
||||||
wQuickActions fyne.Window
|
|
||||||
|
|
||||||
eventManager *event.Manager
|
eventManager *event.Manager
|
||||||
|
|
||||||
@@ -327,15 +317,14 @@ type menuHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type newServiceClientArgs struct {
|
type newServiceClientArgs struct {
|
||||||
addr string
|
addr string
|
||||||
logFile string
|
logFile string
|
||||||
app fyne.App
|
app fyne.App
|
||||||
showSettings bool
|
showSettings bool
|
||||||
showNetworks bool
|
showNetworks bool
|
||||||
showDebug bool
|
showDebug bool
|
||||||
showLoginURL bool
|
showLoginURL bool
|
||||||
showProfiles bool
|
showProfiles bool
|
||||||
showQuickActions bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newServiceClient instance constructor
|
// newServiceClient instance constructor
|
||||||
@@ -371,8 +360,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
|||||||
s.showDebugUI()
|
s.showDebugUI()
|
||||||
case args.showProfiles:
|
case args.showProfiles:
|
||||||
s.showProfilesUI()
|
s.showProfilesUI()
|
||||||
case args.showQuickActions:
|
|
||||||
s.showQuickActionsUI()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
@@ -454,7 +441,6 @@ func (s *serviceClient) showSettingsUI() {
|
|||||||
s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil)
|
s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil)
|
||||||
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
|
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
|
||||||
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
|
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
|
||||||
s.iSSHJWTCacheTTL = widget.NewEntry()
|
|
||||||
|
|
||||||
s.wSettings.SetContent(s.getSettingsForm())
|
s.wSettings.SetContent(s.getSettingsForm())
|
||||||
s.wSettings.Resize(fyne.NewSize(600, 400))
|
s.wSettings.Resize(fyne.NewSize(600, 400))
|
||||||
@@ -510,15 +496,11 @@ func (s *serviceClient) saveSettings() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
iMngURL := strings.TrimSpace(s.iMngURL.Text)
|
iMngURL := strings.TrimSpace(s.iMngURL.Text)
|
||||||
|
defer s.wSettings.Close()
|
||||||
|
|
||||||
if s.hasSettingsChanged(iMngURL, port, mtu) {
|
if s.hasSettingsChanged(iMngURL, port, mtu) {
|
||||||
if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil {
|
s.applySettingsChanges(iMngURL, port, mtu)
|
||||||
dialog.ShowError(err, s.wSettings)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.wSettings.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) validateSettings() error {
|
func (s *serviceClient) validateSettings() error {
|
||||||
@@ -535,9 +517,6 @@ func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, errors.New("Invalid interface port")
|
return 0, 0, errors.New("Invalid interface port")
|
||||||
}
|
}
|
||||||
if port < 1 || port > 65535 {
|
|
||||||
return 0, 0, errors.New("Invalid interface port: out of range 1-65535")
|
|
||||||
}
|
|
||||||
|
|
||||||
var mtu int64
|
var mtu int64
|
||||||
mtuText := strings.TrimSpace(s.iMTU.Text)
|
mtuText := strings.TrimSpace(s.iMTU.Text)
|
||||||
@@ -569,21 +548,20 @@ func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool
|
|||||||
s.hasSSHChanges()
|
s.hasSSHChanges()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error {
|
func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) {
|
||||||
s.managementURL = iMngURL
|
s.managementURL = iMngURL
|
||||||
s.preSharedKey = s.iPreSharedKey.Text
|
s.preSharedKey = s.iPreSharedKey.Text
|
||||||
s.mtu = uint16(mtu)
|
s.mtu = uint16(mtu)
|
||||||
|
|
||||||
req, err := s.buildSetConfigRequest(iMngURL, port, mtu)
|
req, err := s.buildSetConfigRequest(iMngURL, port, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("build config request: %w", err)
|
log.Errorf("build config request: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.sendConfigUpdate(req); err != nil {
|
if err := s.sendConfigUpdate(req); err != nil {
|
||||||
return fmt.Errorf("set configuration: %w", err)
|
dialog.ShowError(fmt.Errorf("Failed to set configuration: %v", err), s.wSettings)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (*proto.SetConfigRequest, error) {
|
func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (*proto.SetConfigRequest, error) {
|
||||||
@@ -621,23 +599,10 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
|
|||||||
|
|
||||||
req.EnableSSHRoot = &s.sEnableSSHRoot.Checked
|
req.EnableSSHRoot = &s.sEnableSSHRoot.Checked
|
||||||
req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked
|
req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked
|
||||||
req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked
|
req.EnableSSHLocalPortForward = &s.sEnableSSHLocalPortForward.Checked
|
||||||
req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked
|
req.EnableSSHRemotePortForward = &s.sEnableSSHRemotePortForward.Checked
|
||||||
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
|
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
|
||||||
|
|
||||||
sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text)
|
|
||||||
if sshJWTCacheTTLText != "" {
|
|
||||||
sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Invalid SSH JWT Cache TTL value")
|
|
||||||
}
|
|
||||||
if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL {
|
|
||||||
return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL)
|
|
||||||
}
|
|
||||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
|
||||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||||
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
|
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
|
||||||
}
|
}
|
||||||
@@ -723,27 +688,16 @@ func (s *serviceClient) getSSHForm() *widget.Form {
|
|||||||
{Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward},
|
{Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward},
|
||||||
{Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward},
|
{Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward},
|
||||||
{Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth},
|
{Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth},
|
||||||
{Text: "JWT Cache TTL (seconds, 0=disabled)", Widget: s.iSSHJWTCacheTTL},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) hasSSHChanges() bool {
|
func (s *serviceClient) hasSSHChanges() bool {
|
||||||
currentSSHJWTCacheTTL := s.sshJWTCacheTTL
|
|
||||||
if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" {
|
|
||||||
val, err := strconv.Atoi(text)
|
|
||||||
if err != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
currentSSHJWTCacheTTL = val
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.enableSSHRoot != s.sEnableSSHRoot.Checked ||
|
return s.enableSSHRoot != s.sEnableSSHRoot.Checked ||
|
||||||
s.enableSSHSFTP != s.sEnableSSHSFTP.Checked ||
|
s.enableSSHSFTP != s.sEnableSSHSFTP.Checked ||
|
||||||
s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked ||
|
s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked ||
|
||||||
s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked ||
|
s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked ||
|
||||||
s.disableSSHAuth != s.sDisableSSHAuth.Checked ||
|
s.disableSSHAuth != s.sDisableSSHAuth.Checked
|
||||||
s.sshJWTCacheTTL != currentSSHJWTCacheTTL
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
|
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
|
||||||
@@ -762,20 +716,11 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
|
|||||||
return nil, fmt.Errorf("get current user: %w", err)
|
return nil, fmt.Errorf("get current user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loginReq := &proto.LoginRequest{
|
loginResp, err := conn.Login(ctx, &proto.LoginRequest{
|
||||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||||
ProfileName: &activeProf.Name,
|
ProfileName: &activeProf.Name,
|
||||||
Username: &currUser.Username,
|
Username: &currUser.Username,
|
||||||
}
|
})
|
||||||
|
|
||||||
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
|
||||||
} else if profileState.Email != "" {
|
|
||||||
loginReq.Hint = &profileState.Email
|
|
||||||
}
|
|
||||||
|
|
||||||
loginResp, err := conn.Login(ctx, loginReq)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("login to management: %w", err)
|
return nil, fmt.Errorf("login to management: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1280,9 +1225,6 @@ func (s *serviceClient) getSrvConfig() {
|
|||||||
if cfg.DisableSSHAuth != nil {
|
if cfg.DisableSSHAuth != nil {
|
||||||
s.disableSSHAuth = *cfg.DisableSSHAuth
|
s.disableSSHAuth = *cfg.DisableSSHAuth
|
||||||
}
|
}
|
||||||
if cfg.SSHJWTCacheTTL != nil {
|
|
||||||
s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.showAdvancedSettings {
|
if s.showAdvancedSettings {
|
||||||
s.iMngURL.SetText(s.managementURL)
|
s.iMngURL.SetText(s.managementURL)
|
||||||
@@ -1319,9 +1261,6 @@ func (s *serviceClient) getSrvConfig() {
|
|||||||
if cfg.DisableSSHAuth != nil {
|
if cfg.DisableSSHAuth != nil {
|
||||||
s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth)
|
s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth)
|
||||||
}
|
}
|
||||||
if cfg.SSHJWTCacheTTL != nil {
|
|
||||||
s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.mNotifications == nil {
|
if s.mNotifications == nil {
|
||||||
@@ -1392,14 +1331,21 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
|
|||||||
config.DisableServerRoutes = cfg.DisableServerRoutes
|
config.DisableServerRoutes = cfg.DisableServerRoutes
|
||||||
config.BlockLANAccess = cfg.BlockLanAccess
|
config.BlockLANAccess = cfg.BlockLanAccess
|
||||||
|
|
||||||
config.EnableSSHRoot = &cfg.EnableSSHRoot
|
if cfg.EnableSSHRoot {
|
||||||
config.EnableSSHSFTP = &cfg.EnableSSHSFTP
|
config.EnableSSHRoot = &cfg.EnableSSHRoot
|
||||||
config.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding
|
}
|
||||||
config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
|
if cfg.EnableSSHSFTP {
|
||||||
config.DisableSSHAuth = &cfg.DisableSSHAuth
|
config.EnableSSHSFTP = &cfg.EnableSSHSFTP
|
||||||
|
}
|
||||||
ttl := int(cfg.SshJWTCacheTTL)
|
if cfg.EnableSSHLocalPortForwarding {
|
||||||
config.SSHJWTCacheTTL = &ttl
|
config.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding
|
||||||
|
}
|
||||||
|
if cfg.EnableSSHRemotePortForwarding {
|
||||||
|
config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
|
||||||
|
}
|
||||||
|
if cfg.DisableSSHAuth {
|
||||||
|
config.DisableSSHAuth = &cfg.DisableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
return &config
|
return &config
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ import (
|
|||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -289,18 +291,19 @@ func (s *serviceClient) handleRunForDuration(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer s.restoreServiceState(conn, initialState)
|
statusOutput, err := s.collectDebugData(conn, initialState, params, progressUI)
|
||||||
|
if err != nil {
|
||||||
if err := s.collectDebugData(conn, initialState, params, progressUI); err != nil {
|
|
||||||
handleError(progressUI, err.Error())
|
handleError(progressUI, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.createDebugBundleFromCollection(conn, params, progressUI); err != nil {
|
if err := s.createDebugBundleFromCollection(conn, params, statusOutput, progressUI); err != nil {
|
||||||
handleError(progressUI, err.Error())
|
handleError(progressUI, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.restoreServiceState(conn, initialState)
|
||||||
|
|
||||||
progressUI.statusLabel.SetText("Bundle created successfully")
|
progressUI.statusLabel.SetText("Bundle created successfully")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -414,33 +417,68 @@ func (s *serviceClient) collectDebugData(
|
|||||||
state *debugInitialState,
|
state *debugInitialState,
|
||||||
params *debugCollectionParams,
|
params *debugCollectionParams,
|
||||||
progress *progressUI,
|
progress *progressUI,
|
||||||
) error {
|
) (string, error) {
|
||||||
ctx, cancel := context.WithTimeout(s.ctx, params.duration)
|
ctx, cancel := context.WithTimeout(s.ctx, params.duration)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
startProgressTracker(ctx, &wg, params.duration, progress)
|
startProgressTracker(ctx, &wg, params.duration, progress)
|
||||||
|
|
||||||
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
|
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pm := profilemanager.NewProfileManager()
|
||||||
|
var profName string
|
||||||
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
profName = activeProf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to get post-up status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var postUpStatusOutput string
|
||||||
|
if postUpStatus != nil {
|
||||||
|
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||||
|
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
|
}
|
||||||
|
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||||
|
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
progress.progressBar.Hide()
|
progress.progressBar.Hide()
|
||||||
progress.statusLabel.SetText("Collecting debug data...")
|
progress.statusLabel.SetText("Collecting debug data...")
|
||||||
|
|
||||||
return nil
|
preDownStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to get pre-down status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var preDownStatusOutput string
|
||||||
|
if preDownStatus != nil {
|
||||||
|
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||||
|
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
|
}
|
||||||
|
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||||
|
time.Now().Format(time.RFC3339), params.duration)
|
||||||
|
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, preDownStatusOutput)
|
||||||
|
|
||||||
|
return statusOutput, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the debug bundle with collected data
|
// Create the debug bundle with collected data
|
||||||
func (s *serviceClient) createDebugBundleFromCollection(
|
func (s *serviceClient) createDebugBundleFromCollection(
|
||||||
conn proto.DaemonServiceClient,
|
conn proto.DaemonServiceClient,
|
||||||
params *debugCollectionParams,
|
params *debugCollectionParams,
|
||||||
|
statusOutput string,
|
||||||
progress *progressUI,
|
progress *progressUI,
|
||||||
) error {
|
) error {
|
||||||
progress.statusLabel.SetText("Creating debug bundle with collected logs...")
|
progress.statusLabel.SetText("Creating debug bundle with collected logs...")
|
||||||
|
|
||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: params.anonymize,
|
Anonymize: params.anonymize,
|
||||||
|
Status: statusOutput,
|
||||||
SystemInfo: params.systemInfo,
|
SystemInfo: params.systemInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -462,7 +500,7 @@ func (s *serviceClient) createDebugBundleFromCollection(
|
|||||||
if uploadFailureReason != "" {
|
if uploadFailureReason != "" {
|
||||||
showUploadFailedDialog(progress.window, localPath, uploadFailureReason)
|
showUploadFailedDialog(progress.window, localPath, uploadFailureReason)
|
||||||
} else {
|
} else {
|
||||||
showUploadSuccessDialog(s.app, progress.window, localPath, uploadedKey)
|
showUploadSuccessDialog(progress.window, localPath, uploadedKey)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
showBundleCreatedDialog(progress.window, localPath)
|
showBundleCreatedDialog(progress.window, localPath)
|
||||||
@@ -527,7 +565,7 @@ func (s *serviceClient) handleDebugCreation(
|
|||||||
if uploadFailureReason != "" {
|
if uploadFailureReason != "" {
|
||||||
showUploadFailedDialog(w, localPath, uploadFailureReason)
|
showUploadFailedDialog(w, localPath, uploadFailureReason)
|
||||||
} else {
|
} else {
|
||||||
showUploadSuccessDialog(s.app, w, localPath, uploadedKey)
|
showUploadSuccessDialog(w, localPath, uploadedKey)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
showBundleCreatedDialog(w, localPath)
|
showBundleCreatedDialog(w, localPath)
|
||||||
@@ -543,8 +581,26 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
|||||||
return nil, fmt.Errorf("get client: %v", err)
|
return nil, fmt.Errorf("get client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pm := profilemanager.NewProfileManager()
|
||||||
|
var profName string
|
||||||
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
profName = activeProf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get status for debug bundle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var statusOutput string
|
||||||
|
if statusResp != nil {
|
||||||
|
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||||
|
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
|
}
|
||||||
|
|
||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymize,
|
Anonymize: anonymize,
|
||||||
|
Status: statusOutput,
|
||||||
SystemInfo: systemInfo,
|
SystemInfo: systemInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -609,7 +665,7 @@ func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// showUploadSuccessDialog displays a dialog when upload succeeds
|
// showUploadSuccessDialog displays a dialog when upload succeeds
|
||||||
func showUploadSuccessDialog(a fyne.App, w fyne.Window, localPath, uploadedKey string) {
|
func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
|
||||||
log.Infof("Upload key: %s", uploadedKey)
|
log.Infof("Upload key: %s", uploadedKey)
|
||||||
keyEntry := widget.NewEntry()
|
keyEntry := widget.NewEntry()
|
||||||
keyEntry.SetText(uploadedKey)
|
keyEntry.SetText(uploadedKey)
|
||||||
@@ -627,7 +683,7 @@ func showUploadSuccessDialog(a fyne.App, w fyne.Window, localPath, uploadedKey s
|
|||||||
customDialog := dialog.NewCustom("Upload Successful", "OK", content, w)
|
customDialog := dialog.NewCustom("Upload Successful", "OK", content, w)
|
||||||
|
|
||||||
copyBtn := createButtonWithAction("Copy key", func() {
|
copyBtn := createButtonWithAction("Copy key", func() {
|
||||||
a.Clipboard().SetContent(uploadedKey)
|
w.Clipboard().SetContent(uploadedKey)
|
||||||
log.Info("Upload key copied to clipboard")
|
log.Info("Upload key copied to clipboard")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -9,9 +9,6 @@ import (
|
|||||||
//go:embed assets/netbird.png
|
//go:embed assets/netbird.png
|
||||||
var iconAbout []byte
|
var iconAbout []byte
|
||||||
|
|
||||||
//go:embed assets/netbird-disconnected.png
|
|
||||||
var iconAboutDisconnected []byte
|
|
||||||
|
|
||||||
//go:embed assets/netbird-systemtray-connected.png
|
//go:embed assets/netbird-systemtray-connected.png
|
||||||
var iconConnected []byte
|
var iconConnected []byte
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,6 @@ import (
|
|||||||
//go:embed assets/netbird.ico
|
//go:embed assets/netbird.ico
|
||||||
var iconAbout []byte
|
var iconAbout []byte
|
||||||
|
|
||||||
//go:embed assets/netbird-disconnected.ico
|
|
||||||
var iconAboutDisconnected []byte
|
|
||||||
|
|
||||||
//go:embed assets/netbird-systemtray-connected.ico
|
//go:embed assets/netbird-systemtray-connected.ico
|
||||||
var iconConnected []byte
|
var iconConnected []byte
|
||||||
|
|
||||||
|
|||||||
@@ -1,349 +0,0 @@
|
|||||||
//go:build !(linux && 386)
|
|
||||||
|
|
||||||
//go:generate fyne bundle -o quickactions_assets.go assets/connected.png
|
|
||||||
//go:generate fyne bundle -o quickactions_assets.go -append assets/disconnected.png
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
_ "embed"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"fyne.io/fyne/v2"
|
|
||||||
"fyne.io/fyne/v2/canvas"
|
|
||||||
"fyne.io/fyne/v2/container"
|
|
||||||
"fyne.io/fyne/v2/layout"
|
|
||||||
"fyne.io/fyne/v2/widget"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type quickActionsUiState struct {
|
|
||||||
connectionStatus string
|
|
||||||
isToggleButtonEnabled bool
|
|
||||||
isConnectionChanged bool
|
|
||||||
toggleAction func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newQuickActionsUiState() quickActionsUiState {
|
|
||||||
return quickActionsUiState{
|
|
||||||
connectionStatus: string(internal.StatusIdle),
|
|
||||||
isToggleButtonEnabled: false,
|
|
||||||
isConnectionChanged: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientConnectionStatusProvider interface {
|
|
||||||
connectionStatus(ctx context.Context) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type daemonClientConnectionStatusProvider struct {
|
|
||||||
client proto.DaemonServiceClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d daemonClientConnectionStatusProvider) connectionStatus(ctx context.Context) (string, error) {
|
|
||||||
childCtx, cancel := context.WithTimeout(ctx, 400*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
status, err := d.client.Status(childCtx, &proto.StatusRequest{})
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return status.Status, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientCommand interface {
|
|
||||||
execute() error
|
|
||||||
}
|
|
||||||
|
|
||||||
type connectCommand struct {
|
|
||||||
connectClient func() error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c connectCommand) execute() error {
|
|
||||||
return c.connectClient()
|
|
||||||
}
|
|
||||||
|
|
||||||
type disconnectCommand struct {
|
|
||||||
disconnectClient func() error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c disconnectCommand) execute() error {
|
|
||||||
return c.disconnectClient()
|
|
||||||
}
|
|
||||||
|
|
||||||
type quickActionsViewModel struct {
|
|
||||||
provider clientConnectionStatusProvider
|
|
||||||
connect clientCommand
|
|
||||||
disconnect clientCommand
|
|
||||||
uiChan chan quickActionsUiState
|
|
||||||
isWatchingConnectionStatus atomic.Bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newQuickActionsViewModel(ctx context.Context, provider clientConnectionStatusProvider, connect, disconnect clientCommand, uiChan chan quickActionsUiState) {
|
|
||||||
viewModel := quickActionsViewModel{
|
|
||||||
provider: provider,
|
|
||||||
connect: connect,
|
|
||||||
disconnect: disconnect,
|
|
||||||
uiChan: uiChan,
|
|
||||||
}
|
|
||||||
|
|
||||||
viewModel.isWatchingConnectionStatus.Store(true)
|
|
||||||
|
|
||||||
// base UI status
|
|
||||||
uiChan <- newQuickActionsUiState()
|
|
||||||
|
|
||||||
// this retrieves the current connection status
|
|
||||||
// and pushes the UI state that reflects it via uiChan
|
|
||||||
go viewModel.watchConnectionStatus(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *quickActionsViewModel) updateUiState(ctx context.Context) {
|
|
||||||
uiState := newQuickActionsUiState()
|
|
||||||
connectionStatus, err := q.provider.connectionStatus(ctx)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Status: Error - %v", err)
|
|
||||||
q.uiChan <- uiState
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if connectionStatus == string(internal.StatusConnected) {
|
|
||||||
uiState.toggleAction = func() {
|
|
||||||
q.executeCommand(q.disconnect)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
uiState.toggleAction = func() {
|
|
||||||
q.executeCommand(q.connect)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
uiState.isToggleButtonEnabled = true
|
|
||||||
uiState.connectionStatus = connectionStatus
|
|
||||||
q.uiChan <- uiState
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *quickActionsViewModel) watchConnectionStatus(ctx context.Context) {
|
|
||||||
ticker := time.NewTicker(1000 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if q.isWatchingConnectionStatus.Load() {
|
|
||||||
q.updateUiState(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *quickActionsViewModel) executeCommand(command clientCommand) {
|
|
||||||
uiState := newQuickActionsUiState()
|
|
||||||
// newQuickActionsUiState starts with Idle connection status,
|
|
||||||
// and all that's necessary here is to just disable the toggle button.
|
|
||||||
uiState.connectionStatus = ""
|
|
||||||
|
|
||||||
q.uiChan <- uiState
|
|
||||||
|
|
||||||
q.isWatchingConnectionStatus.Store(false)
|
|
||||||
|
|
||||||
err := command.execute()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Status: Error - %v", err)
|
|
||||||
q.isWatchingConnectionStatus.Store(true)
|
|
||||||
} else {
|
|
||||||
uiState = newQuickActionsUiState()
|
|
||||||
uiState.isConnectionChanged = true
|
|
||||||
q.uiChan <- uiState
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSystemTrayName() string {
|
|
||||||
os := runtime.GOOS
|
|
||||||
switch os {
|
|
||||||
case "darwin":
|
|
||||||
return "menu bar"
|
|
||||||
default:
|
|
||||||
return "system tray"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceClient) getNetBirdImage(name string, content []byte) *canvas.Image {
|
|
||||||
imageSize := fyne.NewSize(64, 64)
|
|
||||||
|
|
||||||
resource := fyne.NewStaticResource(name, content)
|
|
||||||
image := canvas.NewImageFromResource(resource)
|
|
||||||
image.FillMode = canvas.ImageFillContain
|
|
||||||
image.SetMinSize(imageSize)
|
|
||||||
image.Resize(imageSize)
|
|
||||||
|
|
||||||
return image
|
|
||||||
}
|
|
||||||
|
|
||||||
type quickActionsUiComponents struct {
|
|
||||||
content *fyne.Container
|
|
||||||
toggleConnectionButton *widget.Button
|
|
||||||
connectedLabelText, disconnectedLabelText string
|
|
||||||
connectedImage, disconnectedImage *canvas.Image
|
|
||||||
connectedCircleRes, disconnectedCircleRes fyne.Resource
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyQuickActionsUiState applies a single UI state to the quick actions window.
|
|
||||||
// It closes the window and returns true if the connection status has changed,
|
|
||||||
// in which case the caller should stop processing further states.
|
|
||||||
func (s *serviceClient) applyQuickActionsUiState(
|
|
||||||
uiState quickActionsUiState,
|
|
||||||
components quickActionsUiComponents,
|
|
||||||
) bool {
|
|
||||||
if uiState.isConnectionChanged {
|
|
||||||
fyne.DoAndWait(func() {
|
|
||||||
s.wQuickActions.Close()
|
|
||||||
})
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
var logo *canvas.Image
|
|
||||||
var buttonText string
|
|
||||||
var buttonIcon fyne.Resource
|
|
||||||
|
|
||||||
if uiState.connectionStatus == string(internal.StatusConnected) {
|
|
||||||
buttonText = components.connectedLabelText
|
|
||||||
buttonIcon = components.connectedCircleRes
|
|
||||||
logo = components.connectedImage
|
|
||||||
} else if uiState.connectionStatus == string(internal.StatusIdle) {
|
|
||||||
buttonText = components.disconnectedLabelText
|
|
||||||
buttonIcon = components.disconnectedCircleRes
|
|
||||||
logo = components.disconnectedImage
|
|
||||||
}
|
|
||||||
|
|
||||||
fyne.DoAndWait(func() {
|
|
||||||
if buttonText != "" {
|
|
||||||
components.toggleConnectionButton.SetText(buttonText)
|
|
||||||
}
|
|
||||||
|
|
||||||
if buttonIcon != nil {
|
|
||||||
components.toggleConnectionButton.SetIcon(buttonIcon)
|
|
||||||
}
|
|
||||||
|
|
||||||
if uiState.isToggleButtonEnabled {
|
|
||||||
components.toggleConnectionButton.Enable()
|
|
||||||
} else {
|
|
||||||
components.toggleConnectionButton.Disable()
|
|
||||||
}
|
|
||||||
|
|
||||||
components.toggleConnectionButton.OnTapped = func() {
|
|
||||||
if uiState.toggleAction != nil {
|
|
||||||
go uiState.toggleAction()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
components.toggleConnectionButton.Refresh()
|
|
||||||
|
|
||||||
// the second position in the content's object array is the NetBird logo.
|
|
||||||
if logo != nil {
|
|
||||||
components.content.Objects[1] = logo
|
|
||||||
components.content.Refresh()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// showQuickActionsUI displays a simple window with the NetBird logo and a connection toggle button.
|
|
||||||
func (s *serviceClient) showQuickActionsUI() {
|
|
||||||
s.wQuickActions = s.app.NewWindow("NetBird")
|
|
||||||
vmCtx, vmCancel := context.WithCancel(s.ctx)
|
|
||||||
s.wQuickActions.SetOnClosed(vmCancel)
|
|
||||||
|
|
||||||
client, err := s.getSrvClient(defaultFailTimeout)
|
|
||||||
|
|
||||||
connCmd := connectCommand{
|
|
||||||
connectClient: func() error {
|
|
||||||
return s.menuUpClick(s.ctx)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
disConnCmd := disconnectCommand{
|
|
||||||
disconnectClient: func() error {
|
|
||||||
return s.menuDownClick()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("get service client: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
uiChan := make(chan quickActionsUiState, 1)
|
|
||||||
newQuickActionsViewModel(vmCtx, daemonClientConnectionStatusProvider{client: client}, connCmd, disConnCmd, uiChan)
|
|
||||||
|
|
||||||
connectedImage := s.getNetBirdImage("netbird.png", iconAbout)
|
|
||||||
disconnectedImage := s.getNetBirdImage("netbird-disconnected.png", iconAboutDisconnected)
|
|
||||||
|
|
||||||
connectedCircle := canvas.NewImageFromResource(resourceConnectedPng)
|
|
||||||
disconnectedCircle := canvas.NewImageFromResource(resourceDisconnectedPng)
|
|
||||||
|
|
||||||
connectedLabelText := "Disconnect"
|
|
||||||
disconnectedLabelText := "Connect"
|
|
||||||
|
|
||||||
toggleConnectionButton := widget.NewButtonWithIcon(disconnectedLabelText, disconnectedCircle.Resource, func() {
|
|
||||||
// This button's tap function will be set when an ui state arrives via the uiChan channel.
|
|
||||||
})
|
|
||||||
|
|
||||||
// Button starts disabled until the first ui state arrives.
|
|
||||||
toggleConnectionButton.Disable()
|
|
||||||
|
|
||||||
hintLabelText := fmt.Sprintf("You can always access NetBird from your %s.", getSystemTrayName())
|
|
||||||
hintLabel := widget.NewLabel(hintLabelText)
|
|
||||||
|
|
||||||
content := container.NewVBox(
|
|
||||||
layout.NewSpacer(),
|
|
||||||
disconnectedImage,
|
|
||||||
layout.NewSpacer(),
|
|
||||||
container.NewCenter(toggleConnectionButton),
|
|
||||||
layout.NewSpacer(),
|
|
||||||
container.NewCenter(hintLabel),
|
|
||||||
)
|
|
||||||
|
|
||||||
// this watches for ui state updates.
|
|
||||||
go func() {
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-vmCtx.Done():
|
|
||||||
return
|
|
||||||
case uiState, ok := <-uiChan:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
closed := s.applyQuickActionsUiState(
|
|
||||||
uiState,
|
|
||||||
quickActionsUiComponents{
|
|
||||||
content,
|
|
||||||
toggleConnectionButton,
|
|
||||||
connectedLabelText, disconnectedLabelText,
|
|
||||||
connectedImage, disconnectedImage,
|
|
||||||
connectedCircle.Resource, disconnectedCircle.Resource,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if closed {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
s.wQuickActions.SetContent(content)
|
|
||||||
s.wQuickActions.Resize(fyne.NewSize(400, 200))
|
|
||||||
s.wQuickActions.SetFixedSize(true)
|
|
||||||
s.wQuickActions.Show()
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
// auto-generated
|
|
||||||
// Code generated by '$ fyne bundle'. DO NOT EDIT.
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
_ "embed"
|
|
||||||
"fyne.io/fyne/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:embed assets/connected.png
|
|
||||||
var resourceConnectedPngData []byte
|
|
||||||
var resourceConnectedPng = &fyne.StaticResource{
|
|
||||||
StaticName: "assets/connected.png",
|
|
||||||
StaticContent: resourceConnectedPngData,
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:embed assets/disconnected.png
|
|
||||||
var resourceDisconnectedPngData []byte
|
|
||||||
var resourceDisconnectedPng = &fyne.StaticResource{
|
|
||||||
StaticName: "assets/disconnected.png",
|
|
||||||
StaticContent: resourceDisconnectedPngData,
|
|
||||||
}
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
//go:build !windows && !(linux && 386)
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// setupSignalHandler sets up a signal handler to listen for SIGUSR1.
|
|
||||||
// When received, it opens the quick actions window.
|
|
||||||
func (s *serviceClient) setupSignalHandler(ctx context.Context) {
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigChan, syscall.SIGUSR1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-sigChan:
|
|
||||||
log.Info("received SIGUSR1 signal, opening quick actions window")
|
|
||||||
s.openQuickActions()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// openQuickActions opens the quick actions window by spawning a new process.
|
|
||||||
func (s *serviceClient) openQuickActions() {
|
|
||||||
proc, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("get executable path: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.CommandContext(s.ctx, proc,
|
|
||||||
"--quick-actions=true",
|
|
||||||
"--daemon-addr="+s.addr,
|
|
||||||
)
|
|
||||||
|
|
||||||
if out := s.attachOutput(cmd); out != nil {
|
|
||||||
defer func() {
|
|
||||||
if err := out.Close(); err != nil {
|
|
||||||
log.Errorf("close log file %s: %v", s.logFile, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr)
|
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
log.Errorf("start quick actions window: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := cmd.Wait(); err != nil {
|
|
||||||
log.Debugf("quick actions window exited: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendShowWindowSignal sends SIGUSR1 to the specified PID.
|
|
||||||
func sendShowWindowSignal(pid int32) error {
|
|
||||||
process, err := os.FindProcess(int(pid))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return process.Signal(syscall.SIGUSR1)
|
|
||||||
}
|
|
||||||
@@ -1,171 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
quickActionsTriggerEventName = `Global\NetBirdQuickActionsTriggerEvent`
|
|
||||||
waitTimeout = 5 * time.Second
|
|
||||||
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
|
|
||||||
desiredAccesses = windows.SYNCHRONIZE | windows.EVENT_MODIFY_STATE
|
|
||||||
)
|
|
||||||
|
|
||||||
func getEventNameUint16Pointer() (*uint16, error) {
|
|
||||||
eventNamePtr, err := windows.UTF16PtrFromString(quickActionsTriggerEventName)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to convert event name '%s' to UTF16: %v", quickActionsTriggerEventName, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return eventNamePtr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupSignalHandler sets up signal handling for Windows.
|
|
||||||
// Windows doesn't support SIGUSR1, so this uses a similar approach using windows.Events.
|
|
||||||
func (s *serviceClient) setupSignalHandler(ctx context.Context) {
|
|
||||||
eventNamePtr, err := getEventNameUint16Pointer()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
|
|
||||||
log.Warnf("Quick actions trigger event '%s' already exists. Attempting to open.", quickActionsTriggerEventName)
|
|
||||||
eventHandle, err = windows.OpenEvent(desiredAccesses, false, eventNamePtr)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to open existing quick actions trigger event '%s': %v", quickActionsTriggerEventName, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Infof("Successfully opened existing quick actions trigger event '%s'.", quickActionsTriggerEventName)
|
|
||||||
} else {
|
|
||||||
log.Errorf("Failed to create quick actions trigger event '%s': %v", quickActionsTriggerEventName, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if eventHandle == windows.InvalidHandle {
|
|
||||||
log.Errorf("Obtained an invalid handle for quick actions trigger event '%s'", quickActionsTriggerEventName)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("Quick actions handler waiting for signal on event: %s", quickActionsTriggerEventName)
|
|
||||||
|
|
||||||
go s.waitForEvent(ctx, eventHandle)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceClient) waitForEvent(ctx context.Context, eventHandle windows.Handle) {
|
|
||||||
defer func() {
|
|
||||||
if err := windows.CloseHandle(eventHandle); err != nil {
|
|
||||||
log.Errorf("Failed to close quick actions event handle '%s': %v", quickActionsTriggerEventName, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
|
|
||||||
|
|
||||||
switch status {
|
|
||||||
case windows.WAIT_OBJECT_0:
|
|
||||||
log.Info("Received signal on quick actions event. Opening quick actions window.")
|
|
||||||
|
|
||||||
// reset the event so it can be triggered again later (manual reset == 1)
|
|
||||||
if err := windows.ResetEvent(eventHandle); err != nil {
|
|
||||||
log.Errorf("Failed to reset quick actions event '%s': %v", quickActionsTriggerEventName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.openQuickActions()
|
|
||||||
case uint32(windows.WAIT_TIMEOUT):
|
|
||||||
|
|
||||||
default:
|
|
||||||
if isDone := logUnexpectedStatus(ctx, status, err); isDone {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func logUnexpectedStatus(ctx context.Context, status uint32, err error) bool {
|
|
||||||
log.Errorf("Unexpected status %d from WaitForSingleObject for quick actions event '%s': %v",
|
|
||||||
status, quickActionsTriggerEventName, err)
|
|
||||||
select {
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
return false
|
|
||||||
case <-ctx.Done():
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// openQuickActions opens the quick actions window by spawning a new process.
|
|
||||||
func (s *serviceClient) openQuickActions() {
|
|
||||||
proc, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("get executable path: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.CommandContext(s.ctx, proc,
|
|
||||||
"--quick-actions=true",
|
|
||||||
"--daemon-addr="+s.addr,
|
|
||||||
)
|
|
||||||
|
|
||||||
if out := s.attachOutput(cmd); out != nil {
|
|
||||||
defer func() {
|
|
||||||
if err := out.Close(); err != nil {
|
|
||||||
log.Errorf("close log file %s: %v", s.logFile, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr)
|
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
log.Errorf("error starting quick actions window: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := cmd.Wait(); err != nil {
|
|
||||||
log.Debugf("quick actions window exited: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendShowWindowSignal(pid int32) error {
|
|
||||||
_, err := os.FindProcess(int(pid))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
eventNamePtr, err := getEventNameUint16Pointer()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
eventHandle, err := windows.OpenEvent(desiredAccesses, false, eventNamePtr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = windows.SetEvent(eventHandle)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Error setting event: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
57
go.mod
57
go.mod
@@ -16,7 +16,7 @@ require (
|
|||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/spf13/cobra v1.7.0
|
github.com/spf13/cobra v1.7.0
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.0
|
||||||
golang.org/x/crypto v0.41.0
|
golang.org/x/crypto v0.41.0
|
||||||
golang.org/x/sys v0.35.0
|
golang.org/x/sys v0.35.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||||
@@ -28,8 +28,8 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
fyne.io/fyne/v2 v2.7.0
|
fyne.io/fyne/v2 v2.5.3
|
||||||
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58
|
fyne.io/systray v1.11.0
|
||||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||||
github.com/awnumar/memguard v0.23.0
|
github.com/awnumar/memguard v0.23.0
|
||||||
github.com/aws/aws-sdk-go-v2 v1.36.3
|
github.com/aws/aws-sdk-go-v2 v1.36.3
|
||||||
@@ -44,7 +44,7 @@ require (
|
|||||||
github.com/eko/gocache/lib/v4 v4.2.0
|
github.com/eko/gocache/lib/v4 v4.2.0
|
||||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
||||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||||
github.com/fsnotify/fsnotify v1.9.0
|
github.com/fsnotify/fsnotify v1.7.0
|
||||||
github.com/gliderlabs/ssh v0.3.8
|
github.com/gliderlabs/ssh v0.3.8
|
||||||
github.com/godbus/dbus/v5 v5.1.0
|
github.com/godbus/dbus/v5 v5.1.0
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||||
@@ -57,16 +57,14 @@ require (
|
|||||||
github.com/hashicorp/go-multierror v1.1.1
|
github.com/hashicorp/go-multierror v1.1.1
|
||||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||||
github.com/hashicorp/go-version v1.6.0
|
github.com/hashicorp/go-version v1.6.0
|
||||||
github.com/jackc/pgx/v5 v5.5.5
|
|
||||||
github.com/libdns/route53 v1.5.0
|
github.com/libdns/route53 v1.5.0
|
||||||
github.com/libp2p/go-netroute v0.2.1
|
github.com/libp2p/go-netroute v0.2.1
|
||||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
|
||||||
github.com/mdlayher/socket v0.5.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/oapi-codegen/runtime v1.1.2
|
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
@@ -86,7 +84,7 @@ require (
|
|||||||
github.com/shirou/gopsutil/v3 v3.24.4
|
github.com/shirou/gopsutil/v3 v3.24.4
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.10.0
|
||||||
github.com/testcontainers/testcontainers-go v0.31.0
|
github.com/testcontainers/testcontainers-go v0.31.0
|
||||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
|
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
|
||||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
|
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
|
||||||
@@ -102,17 +100,15 @@ require (
|
|||||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
||||||
go.opentelemetry.io/otel/metric v1.35.0
|
go.opentelemetry.io/otel/metric v1.35.0
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.35.0
|
go.opentelemetry.io/otel/sdk/metric v1.35.0
|
||||||
go.uber.org/mock v0.5.0
|
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
goauthentik.io/api/v3 v3.2023051.3
|
goauthentik.io/api/v3 v3.2023051.3
|
||||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||||
golang.org/x/mod v0.26.0
|
golang.org/x/mod v0.26.0
|
||||||
golang.org/x/net v0.42.0
|
golang.org/x/net v0.42.0
|
||||||
golang.org/x/oauth2 v0.30.0
|
golang.org/x/oauth2 v0.28.0
|
||||||
golang.org/x/sync v0.16.0
|
golang.org/x/sync v0.16.0
|
||||||
golang.org/x/term v0.34.0
|
golang.org/x/term v0.34.0
|
||||||
golang.org/x/time v0.12.0
|
|
||||||
google.golang.org/api v0.177.0
|
google.golang.org/api v0.177.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gorm.io/driver/mysql v1.5.7
|
gorm.io/driver/mysql v1.5.7
|
||||||
@@ -129,11 +125,10 @@ require (
|
|||||||
dario.cat/mergo v1.0.0 // indirect
|
dario.cat/mergo v1.0.0 // indirect
|
||||||
filippo.io/edwards25519 v1.1.0 // indirect
|
filippo.io/edwards25519 v1.1.0 // indirect
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
||||||
github.com/BurntSushi/toml v1.5.0 // indirect
|
github.com/BurntSushi/toml v1.4.0 // indirect
|
||||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
github.com/Microsoft/hcsshim v0.12.3 // indirect
|
github.com/Microsoft/hcsshim v0.12.3 // indirect
|
||||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
||||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
|
||||||
github.com/awnumar/memcall v0.4.0 // indirect
|
github.com/awnumar/memcall v0.4.0 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
|
||||||
@@ -154,7 +149,7 @@ require (
|
|||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
github.com/caddyserver/zerossl v0.1.3 // indirect
|
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/containerd/containerd v1.7.29 // indirect
|
github.com/containerd/containerd v1.7.27 // indirect
|
||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
@@ -165,12 +160,11 @@ require (
|
|||||||
github.com/docker/go-connections v0.5.0 // indirect
|
github.com/docker/go-connections v0.5.0 // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/fredbi/uri v1.1.1 // indirect
|
github.com/fredbi/uri v1.1.0 // indirect
|
||||||
github.com/fyne-io/gl-js v0.2.0 // indirect
|
github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect
|
||||||
github.com/fyne-io/glfw-js v0.3.0 // indirect
|
github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 // indirect
|
||||||
github.com/fyne-io/image v0.1.1 // indirect
|
github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
|
||||||
github.com/fyne-io/oksvg v0.2.0 // indirect
|
github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
|
||||||
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
|
|
||||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
||||||
github.com/go-logr/logr v1.4.2 // indirect
|
github.com/go-logr/logr v1.4.2 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
@@ -178,7 +172,7 @@ require (
|
|||||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||||
github.com/go-text/render v0.2.0 // indirect
|
github.com/go-text/render v0.2.0 // indirect
|
||||||
github.com/go-text/typesetting v0.2.1 // indirect
|
github.com/go-text/typesetting v0.2.0 // indirect
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||||
github.com/google/btree v1.1.2 // indirect
|
github.com/google/btree v1.1.2 // indirect
|
||||||
@@ -186,19 +180,19 @@ require (
|
|||||||
github.com/google/s2a-go v0.1.7 // indirect
|
github.com/google/s2a-go v0.1.7 // indirect
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||||
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
||||||
github.com/hack-pad/go-indexeddb v0.3.2 // indirect
|
github.com/gopherjs/gopherjs v1.17.2 // indirect
|
||||||
github.com/hack-pad/safejs v0.1.0 // indirect
|
|
||||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||||
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||||
|
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
|
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect
|
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
|
||||||
github.com/kelseyhightower/envconfig v1.4.0 // indirect
|
github.com/kelseyhightower/envconfig v1.4.0 // indirect
|
||||||
github.com/klauspost/compress v1.18.0 // indirect
|
github.com/klauspost/compress v1.18.0 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||||
@@ -218,8 +212,7 @@ require (
|
|||||||
github.com/moby/term v0.5.0 // indirect
|
github.com/moby/term v0.5.0 // indirect
|
||||||
github.com/morikuni/aec v1.0.0 // indirect
|
github.com/morikuni/aec v1.0.0 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
|
github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect
|
|
||||||
github.com/nxadm/tail v1.4.8 // indirect
|
github.com/nxadm/tail v1.4.8 // indirect
|
||||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
@@ -235,26 +228,28 @@ require (
|
|||||||
github.com/prometheus/client_model v0.6.1 // indirect
|
github.com/prometheus/client_model v0.6.1 // indirect
|
||||||
github.com/prometheus/common v0.62.0 // indirect
|
github.com/prometheus/common v0.62.0 // indirect
|
||||||
github.com/prometheus/procfs v0.15.1 // indirect
|
github.com/prometheus/procfs v0.15.1 // indirect
|
||||||
github.com/rymdport/portal v0.4.2 // indirect
|
github.com/rymdport/portal v0.3.0 // indirect
|
||||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
||||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
||||||
github.com/stretchr/objx v0.5.2 // indirect
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||||
github.com/vishvananda/netns v0.0.5 // indirect
|
github.com/vishvananda/netns v0.0.4 // indirect
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||||
github.com/wlynxg/anet v0.0.3 // indirect
|
github.com/wlynxg/anet v0.0.3 // indirect
|
||||||
github.com/yuin/goldmark v1.7.8 // indirect
|
github.com/yuin/goldmark v1.7.1 // indirect
|
||||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||||
go.opencensus.io v0.24.0 // indirect
|
go.opencensus.io v0.24.0 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||||
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||||
|
go.uber.org/mock v0.5.0 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
golang.org/x/image v0.24.0 // indirect
|
golang.org/x/image v0.18.0 // indirect
|
||||||
golang.org/x/text v0.28.0 // indirect
|
golang.org/x/text v0.28.0 // indirect
|
||||||
|
golang.org/x/time v0.5.0 // indirect
|
||||||
golang.org/x/tools v0.35.0 // indirect
|
golang.org/x/tools v0.35.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
||||||
|
|||||||
@@ -1,31 +0,0 @@
|
|||||||
package cache
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
|
||||||
type DNSConfigCache struct {
|
|
||||||
NameServerGroups sync.Map
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNameServerGroup retrieves a cached name server group
|
|
||||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
|
||||||
if c == nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
if value, ok := c.NameServerGroups.Load(key); ok {
|
|
||||||
return value.(*proto.NameServerGroup), true
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNameServerGroup stores a name server group in the cache
|
|
||||||
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
|
|
||||||
if c == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.NameServerGroups.Store(key, value)
|
|
||||||
}
|
|
||||||
@@ -1,784 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"slices"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
"golang.org/x/mod/semver"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Controller struct {
|
|
||||||
repo Repository
|
|
||||||
metrics *metrics
|
|
||||||
// This should not be here, but we need to maintain it for the time being
|
|
||||||
accountManagerMetrics *telemetry.AccountManagerMetrics
|
|
||||||
peersUpdateManager network_map.PeersUpdateManager
|
|
||||||
settingsManager settings.Manager
|
|
||||||
|
|
||||||
accountUpdateLocks sync.Map
|
|
||||||
sendAccountUpdateLocks sync.Map
|
|
||||||
updateAccountPeersBufferInterval atomic.Int64
|
|
||||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
|
||||||
dnsDomain string
|
|
||||||
|
|
||||||
requestBuffer account.RequestBuffer
|
|
||||||
|
|
||||||
proxyController port_forwarding.Controller
|
|
||||||
|
|
||||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
|
||||||
|
|
||||||
holder *types.Holder
|
|
||||||
|
|
||||||
expNewNetworkMap bool
|
|
||||||
expNewNetworkMapAIDs map[string]struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type bufferUpdate struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
next *time.Timer
|
|
||||||
update atomic.Bool
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ network_map.Controller = (*Controller)(nil)
|
|
||||||
|
|
||||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller {
|
|
||||||
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
|
|
||||||
newNetworkMapBuilder = false
|
|
||||||
}
|
|
||||||
|
|
||||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
|
||||||
expIDs := make(map[string]struct{}, len(ids))
|
|
||||||
for _, id := range ids {
|
|
||||||
expIDs[id] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Controller{
|
|
||||||
repo: newRepository(store),
|
|
||||||
metrics: nMetrics,
|
|
||||||
accountManagerMetrics: metrics.AccountManagerMetrics(),
|
|
||||||
peersUpdateManager: peersUpdateManager,
|
|
||||||
requestBuffer: requestBuffer,
|
|
||||||
integratedPeerValidator: integratedPeerValidator,
|
|
||||||
settingsManager: settingsManager,
|
|
||||||
dnsDomain: dnsDomain,
|
|
||||||
|
|
||||||
proxyController: proxyController,
|
|
||||||
|
|
||||||
holder: types.NewHolder(),
|
|
||||||
expNewNetworkMap: newNetworkMapBuilder,
|
|
||||||
expNewNetworkMapAIDs: expIDs,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
|
||||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
|
||||||
var (
|
|
||||||
account *types.Account
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
account = c.getAccountFromHolderOrInit(accountID)
|
|
||||||
} else {
|
|
||||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get account: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
globalStart := time.Now()
|
|
||||||
|
|
||||||
hasPeersConnected := false
|
|
||||||
for _, peer := range account.Peers {
|
|
||||||
if c.peersUpdateManager.HasChannel(peer.ID) {
|
|
||||||
hasPeersConnected = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasPeersConnected {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get validate peers: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
semaphore := make(chan struct{}, 10)
|
|
||||||
|
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
|
||||||
return fmt.Errorf("failed to get proxy network maps: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get flow enabled status: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
|
||||||
|
|
||||||
for _, peer := range account.Peers {
|
|
||||||
if !c.peersUpdateManager.HasChannel(peer.ID) {
|
|
||||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
semaphore <- struct{}{}
|
|
||||||
go func(p *nbpeer.Peer) {
|
|
||||||
defer wg.Done()
|
|
||||||
defer func() { <-semaphore }()
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
|
||||||
start = time.Now()
|
|
||||||
|
|
||||||
var remotePeerNetworkMap *types.NetworkMap
|
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
|
||||||
} else {
|
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
|
||||||
if ok {
|
|
||||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerGroups := account.GetPeerGroups(p.ID)
|
|
||||||
start = time.Now()
|
|
||||||
update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
|
||||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
|
||||||
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
|
|
||||||
}(peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
if c.accountManagerMetrics != nil {
|
|
||||||
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
|
||||||
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
|
|
||||||
|
|
||||||
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
|
||||||
b := bufUpd.(*bufferUpdate)
|
|
||||||
|
|
||||||
if !b.mu.TryLock() {
|
|
||||||
b.update.Store(true)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if b.next != nil {
|
|
||||||
b.next.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer b.mu.Unlock()
|
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
|
||||||
if !b.update.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.update.Store(false)
|
|
||||||
if b.next == nil {
|
|
||||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePeers updates all peers that belong to an account.
|
|
||||||
// Should be called when changes have to be synced to peers.
|
|
||||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
|
||||||
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
|
||||||
return fmt.Errorf("recalculate network map cache: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
|
||||||
if !c.peersUpdateManager.HasChannel(peerId) {
|
|
||||||
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peer := account.GetPeer(peerId)
|
|
||||||
if peer == nil {
|
|
||||||
return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId)
|
|
||||||
}
|
|
||||||
|
|
||||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
|
|
||||||
postureChecks, err := c.getPeerPostureChecks(account, peerId)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
|
|
||||||
return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var remotePeerNetworkMap *types.NetworkMap
|
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountId) {
|
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
|
||||||
} else {
|
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
|
||||||
if ok {
|
|
||||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get extra settings: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerGroups := account.GetPeerGroups(peerId)
|
|
||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
|
||||||
|
|
||||||
update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
|
||||||
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
|
|
||||||
|
|
||||||
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
|
||||||
b := bufUpd.(*bufferUpdate)
|
|
||||||
|
|
||||||
if !b.mu.TryLock() {
|
|
||||||
b.update.Store(true)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if b.next != nil {
|
|
||||||
b.next.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer b.mu.Unlock()
|
|
||||||
_ = c.UpdateAccountPeers(ctx, accountID)
|
|
||||||
if !b.update.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.update.Store(false)
|
|
||||||
if b.next == nil {
|
|
||||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
|
||||||
_ = c.UpdateAccountPeers(ctx, accountID)
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
|
|
||||||
network, err := c.repo.GetAccountNetwork(ctx, accountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peers, err := c.repo.GetAccountPeers(ctx, accountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{
|
|
||||||
RemotePeers: []*proto.RemotePeerConfig{},
|
|
||||||
RemotePeersIsEmpty: true,
|
|
||||||
NetworkMap: &proto.NetworkMap{
|
|
||||||
Serial: network.CurrentSerial(),
|
|
||||||
RemotePeers: []*proto.RemotePeerConfig{},
|
|
||||||
RemotePeersIsEmpty: true,
|
|
||||||
FirewallRules: []*proto.FirewallRule{},
|
|
||||||
FirewallRulesIsEmpty: true,
|
|
||||||
DNSConfig: &proto.DNSConfig{
|
|
||||||
ForwarderPort: dnsFwdPort,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.peersUpdateManager.CloseChannel(ctx, peerId)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
|
||||||
if isRequiresApproval {
|
|
||||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
emptyMap := &types.NetworkMap{
|
|
||||||
Network: network.Copy(),
|
|
||||||
}
|
|
||||||
return peer, emptyMap, nil, 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
account *types.Account
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
account = c.getAccountFromHolderOrInit(accountID)
|
|
||||||
} else {
|
|
||||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
startPosture := time.Now()
|
|
||||||
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, 0, err
|
|
||||||
}
|
|
||||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
|
||||||
|
|
||||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
|
||||||
return nil, nil, nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var networkMap *types.NetworkMap
|
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
|
||||||
} else {
|
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
|
||||||
if ok {
|
|
||||||
networkMap.Merge(proxyNetworkMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
|
||||||
|
|
||||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
|
||||||
c.enrichAccountFromHolder(account)
|
|
||||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) getPeerNetworkMapExp(
|
|
||||||
ctx context.Context,
|
|
||||||
accountId string,
|
|
||||||
peerId string,
|
|
||||||
validatedPeers map[string]struct{},
|
|
||||||
customZone nbdns.CustomZone,
|
|
||||||
metrics *telemetry.AccountManagerMetrics,
|
|
||||||
) *types.NetworkMap {
|
|
||||||
account := c.getAccountFromHolderOrInit(accountId)
|
|
||||||
if account == nil {
|
|
||||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
|
||||||
return &types.NetworkMap{
|
|
||||||
Network: &types.Network{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
|
||||||
c.enrichAccountFromHolder(account)
|
|
||||||
return account.OnPeerAddedUpdNetworkMapCache(peerId)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
|
||||||
c.enrichAccountFromHolder(account)
|
|
||||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
|
||||||
account := c.getAccountFromHolder(accountId)
|
|
||||||
if account == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
account.UpdatePeerInNetworkMapCache(peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
|
||||||
account.RecalculateNetworkMapCache(validatedPeers)
|
|
||||||
c.updateAccountInHolder(account)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
|
||||||
if c.experimentalNetworkMap(accountId) {
|
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.recalculateNetworkMapCache(account, validatedPeers)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) experimentalNetworkMap(accountId string) bool {
|
|
||||||
_, ok := c.expNewNetworkMapAIDs[accountId]
|
|
||||||
return c.expNewNetworkMap || ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
|
|
||||||
a := c.holder.GetAccount(account.Id)
|
|
||||||
if a == nil {
|
|
||||||
c.holder.AddAccount(account)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
account.NetworkMapCache = a.NetworkMapCache
|
|
||||||
if account.NetworkMapCache == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
account.NetworkMapCache.UpdateAccountPointer(account)
|
|
||||||
c.holder.AddAccount(account)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
|
||||||
return c.holder.GetAccount(accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account {
|
|
||||||
a := c.holder.GetAccount(accountID)
|
|
||||||
if a != nil {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return account
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) updateAccountInHolder(account *types.Account) {
|
|
||||||
c.holder.AddAccount(account)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDNSDomain returns the configured dnsDomain
|
|
||||||
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
|
||||||
if settings == nil {
|
|
||||||
return c.dnsDomain
|
|
||||||
}
|
|
||||||
if settings.DNSDomain == "" {
|
|
||||||
return c.dnsDomain
|
|
||||||
}
|
|
||||||
|
|
||||||
return settings.DNSDomain
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
|
||||||
func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
|
|
||||||
peerPostureChecks := make(map[string]*posture.Checks)
|
|
||||||
|
|
||||||
if len(account.PostureChecks) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, policy := range account.Policies {
|
|
||||||
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return maps.Values(peerPostureChecks), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) StartWarmup(ctx context.Context) {
|
|
||||||
var initialInterval int64
|
|
||||||
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
|
|
||||||
interval, err := strconv.Atoi(intervalStr)
|
|
||||||
if err != nil {
|
|
||||||
initialInterval = 1
|
|
||||||
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
|
|
||||||
} else {
|
|
||||||
initialInterval = int64(interval) * 10
|
|
||||||
go func() {
|
|
||||||
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
|
|
||||||
startupPeriod, err := strconv.Atoi(startupPeriodStr)
|
|
||||||
if err != nil {
|
|
||||||
startupPeriod = 1
|
|
||||||
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Duration(startupPeriod) * time.Second)
|
|
||||||
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
|
|
||||||
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond))
|
|
||||||
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
|
|
||||||
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
|
|
||||||
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
|
||||||
if len(peers) == 0 {
|
|
||||||
return int64(network_map.OldForwarderPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
reqVer := semver.Canonical(requiredVersion)
|
|
||||||
|
|
||||||
// Check if all peers have the required version or newer
|
|
||||||
for _, peer := range peers {
|
|
||||||
|
|
||||||
// Development version is always supported
|
|
||||||
if peer.Meta.WtVersion == "development" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
|
|
||||||
if peerVersion == "" {
|
|
||||||
// If any peer doesn't have version info, return 0
|
|
||||||
return int64(network_map.OldForwarderPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compare versions
|
|
||||||
if semver.Compare(peerVersion, reqVer) < 0 {
|
|
||||||
return int64(network_map.OldForwarderPort)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// All peers have the required version or newer
|
|
||||||
return int64(network_map.DnsForwarderPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
|
||||||
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
|
|
||||||
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isInGroup {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
|
||||||
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
|
|
||||||
if postureCheck == nil {
|
|
||||||
return errors.New("failed to add policy posture checks: posture checks not found")
|
|
||||||
}
|
|
||||||
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
|
||||||
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
|
|
||||||
for _, rule := range policy.Rules {
|
|
||||||
if !rule.Enabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sourceGroup := range rule.Sources {
|
|
||||||
group := account.GetGroup(sourceGroup)
|
|
||||||
if group == nil {
|
|
||||||
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
if slices.Contains(group.Peers, peerID) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
|
|
||||||
c.UpdatePeerInNetworkMapCache(accountId, peer)
|
|
||||||
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
|
||||||
func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
|
|
||||||
account, err := c.repo.GetAccountByPeerID(ctx, peerID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer := account.GetPeer(peerID)
|
|
||||||
if peer == nil {
|
|
||||||
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
groups := make(map[string][]string)
|
|
||||||
for groupID, group := range account.Groups {
|
|
||||||
groups[groupID] = group.Peers
|
|
||||||
}
|
|
||||||
|
|
||||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
|
||||||
|
|
||||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var networkMap *types.NetworkMap
|
|
||||||
|
|
||||||
if c.experimentalNetworkMap(peer.AccountID) {
|
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
|
||||||
} else {
|
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
|
||||||
if ok {
|
|
||||||
networkMap.Merge(proxyNetworkMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
return networkMap, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
|
|
||||||
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) IsConnected(peerID string) bool {
|
|
||||||
return c.peersUpdateManager.HasChannel(peerID)
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user