mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-12 10:59:54 +00:00
Compare commits
69 Commits
move-licen
...
test-ldfla
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
619f1588b3 | ||
|
|
60d2a2c7df | ||
|
|
ff9585735b | ||
|
|
194951c88d | ||
|
|
d63f2e5196 | ||
|
|
14b5637555 | ||
|
|
3d3b05c157 | ||
|
|
f36e206238 | ||
|
|
9a808244d7 | ||
|
|
2dec76f8ea | ||
|
|
224bd8ff22 | ||
|
|
fe88a5662e | ||
|
|
f9f6409f94 | ||
|
|
b03154dce5 | ||
|
|
c57364596a | ||
|
|
60f4d5f9b0 | ||
|
|
4eeb2d8deb | ||
|
|
2765bcfb89 | ||
|
|
d71a82769c | ||
|
|
fa6151b849 | ||
|
|
a939c1767c | ||
|
|
938554fb0f | ||
|
|
0d79301141 | ||
|
|
39bec2dd74 | ||
|
|
554c9bcf4b | ||
|
|
f3639675e7 | ||
|
|
a1457f541b | ||
|
|
9cdfb0d78c | ||
|
|
22d796097e | ||
|
|
aa39a5d528 | ||
|
|
e4b41d0ad7 | ||
|
|
9cc9462cd5 | ||
|
|
3176b53968 | ||
|
|
27957036c9 | ||
|
|
6fb568728f | ||
|
|
cc97cffff1 | ||
|
|
1d2a5371ce | ||
|
|
6898e57686 | ||
|
|
c8bc865f2f | ||
|
|
06bb8658b1 | ||
|
|
8fc4fed3a0 | ||
|
|
df14f1399f | ||
|
|
6d6f090764 | ||
|
|
c28275611b | ||
|
|
56f169eede | ||
|
|
07cf9d5895 | ||
|
|
7df49e249d | ||
|
|
dbfc8a52c9 | ||
|
|
98ddac07bf | ||
|
|
48475ddc05 | ||
|
|
6aa4ba7af4 | ||
|
|
2e16c9914a | ||
|
|
5c29d395b2 | ||
|
|
229e0038ee | ||
|
|
75327d9519 | ||
|
|
13febbbfca | ||
|
|
49d36b7e7e | ||
|
|
976787dbf1 | ||
|
|
536b0003ab | ||
|
|
0e9438d658 | ||
|
|
e570570fe5 | ||
|
|
23f9dd04b8 | ||
|
|
7a95bf5652 | ||
|
|
60c5782905 | ||
|
|
5a12c5d345 | ||
|
|
bdae55ab79 | ||
|
|
d01c3d5011 | ||
|
|
17a2af96ea | ||
|
|
cc595da1ad |
119
.github/workflows/check-license-dependencies.yml
vendored
119
.github/workflows/check-license-dependencies.yml
vendored
@@ -3,39 +3,108 @@ name: Check License Dependencies
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
|
||||
jobs:
|
||||
check-dependencies:
|
||||
check-internal-dependencies:
|
||||
name: Check Internal AGPL Dependencies
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
echo ""
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
FOUND_ISSUES=0
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
FOUND_ISSUES=1
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
echo ""
|
||||
echo "✅ All internal license dependencies are clean"
|
||||
fi
|
||||
|
||||
check-external-licenses:
|
||||
name: Check External GPL/AGPL Licenses
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache: true
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
FOUND_ISSUES=0
|
||||
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ ! -z "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
FOUND_ISSUES=1
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
echo ""
|
||||
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo ""
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages will change license and should not be imported by client or shared code"
|
||||
exit 1
|
||||
else
|
||||
echo ""
|
||||
echo "✅ All license dependencies are clean"
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
|
||||
2
.github/workflows/wasm-build-validation.yml
vendored
2
.github/workflows/wasm-build-validation.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
- name: Build Wasm client
|
||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm -ldflags="-s -w" ./client/wasm/cmd
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
- name: Check Wasm build size
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
|
||||
@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||
p.configInput.ServerSSHAllowed = &allowed
|
||||
}
|
||||
|
||||
// GetEnableSSHRoot reads SSH root login setting from config file
|
||||
func (p *Preferences) GetEnableSSHRoot() (bool, error) {
|
||||
if p.configInput.EnableSSHRoot != nil {
|
||||
return *p.configInput.EnableSSHRoot, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHRoot == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHRoot, err
|
||||
}
|
||||
|
||||
// SetEnableSSHRoot stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHRoot(enabled bool) {
|
||||
p.configInput.EnableSSHRoot = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHSFTP reads SSH SFTP setting from config file
|
||||
func (p *Preferences) GetEnableSSHSFTP() (bool, error) {
|
||||
if p.configInput.EnableSSHSFTP != nil {
|
||||
return *p.configInput.EnableSSHSFTP, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHSFTP == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHSFTP, err
|
||||
}
|
||||
|
||||
// SetEnableSSHSFTP stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHSFTP(enabled bool) {
|
||||
p.configInput.EnableSSHSFTP = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file
|
||||
func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) {
|
||||
if p.configInput.EnableSSHLocalPortForwarding != nil {
|
||||
return *p.configInput.EnableSSHLocalPortForwarding, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHLocalPortForwarding == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHLocalPortForwarding, err
|
||||
}
|
||||
|
||||
// SetEnableSSHLocalPortForwarding stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) {
|
||||
p.configInput.EnableSSHLocalPortForwarding = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file
|
||||
func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) {
|
||||
if p.configInput.EnableSSHRemotePortForwarding != nil {
|
||||
return *p.configInput.EnableSSHRemotePortForwarding, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHRemotePortForwarding == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHRemotePortForwarding, err
|
||||
}
|
||||
|
||||
// SetEnableSSHRemotePortForwarding stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) {
|
||||
p.configInput.EnableSSHRemotePortForwarding = &enabled
|
||||
}
|
||||
|
||||
// GetBlockInbound reads block inbound setting from config file
|
||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||
if p.configInput.BlockInbound != nil {
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
@@ -98,7 +97,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
@@ -220,9 +218,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
@@ -230,11 +225,8 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
@@ -301,25 +293,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context(), true)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -378,7 +351,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
|
||||
InternalConfig: config,
|
||||
StatusRecorder: recorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogFile: logFilePath,
|
||||
LogPath: logFilePath,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
IncludeSystemInfo: true,
|
||||
|
||||
@@ -106,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
Username: &username,
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
@@ -241,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -269,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
@@ -286,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
|
||||
jwtToken := ""
|
||||
if setupKey == "" && needsLogin {
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -315,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
|
||||
hint := ""
|
||||
pm := profilemanager.NewProfileManager()
|
||||
profileState, err := pm.GetProfileState(profileName)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
hint = profileState.Email
|
||||
}
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -35,7 +35,6 @@ const (
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||
@@ -64,7 +63,6 @@ var (
|
||||
customDNSAddress string
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
serverSSHAllowed bool
|
||||
interfaceName string
|
||||
wireguardPort uint16
|
||||
networkMonitor bool
|
||||
@@ -176,7 +174,6 @@ func init() {
|
||||
)
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
||||
|
||||
|
||||
@@ -259,6 +259,7 @@ func isServiceRunning() (bool, error) {
|
||||
}
|
||||
|
||||
const (
|
||||
networkdConf = "/etc/systemd/networkd.conf"
|
||||
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
||||
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
||||
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
||||
@@ -273,12 +274,16 @@ ManageForeignRoutingPolicyRules=no
|
||||
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
||||
// systemd-networkd from removing NetBird's routes and policy rules.
|
||||
func configureSystemdNetworkd() error {
|
||||
parentDir := filepath.Dir(networkdConfDir)
|
||||
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
|
||||
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
|
||||
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
|
||||
log.Debug("systemd-networkd not in use, skipping configuration")
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
|
||||
return fmt.Errorf("create networkd.conf.d directory: %w", err)
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
||||
return fmt.Errorf("write networkd configuration: %w", err)
|
||||
|
||||
@@ -3,125 +3,809 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
port int
|
||||
userName = "root"
|
||||
host string
|
||||
const (
|
||||
sshUsernameDesc = "SSH username"
|
||||
hostArgumentRequired = "host argument required"
|
||||
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
enableSSHRootFlag = "enable-ssh-root"
|
||||
enableSSHSFTPFlag = "enable-ssh-sftp"
|
||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||
disableSSHAuthFlag = "disable-ssh-auth"
|
||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||
)
|
||||
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh [user@]host",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("requires a host argument")
|
||||
}
|
||||
var (
|
||||
port int
|
||||
username string
|
||||
host string
|
||||
command string
|
||||
localForwards []string
|
||||
remoteForwards []string
|
||||
strictHostKeyChecking bool
|
||||
knownHostsFile string
|
||||
identityFile string
|
||||
skipCachedToken bool
|
||||
requestPTY bool
|
||||
)
|
||||
|
||||
split := strings.Split(args[0], "@")
|
||||
if len(split) == 2 {
|
||||
userName = split[0]
|
||||
host = split[1]
|
||||
} else {
|
||||
host = args[0]
|
||||
}
|
||||
var (
|
||||
serverSSHAllowed bool
|
||||
enableSSHRoot bool
|
||||
enableSSHSFTP bool
|
||||
enableSSHLocalPortForward bool
|
||||
enableSSHRemotePortForward bool
|
||||
disableSSHAuth bool
|
||||
sshJWTCacheTTL int
|
||||
)
|
||||
|
||||
return nil
|
||||
},
|
||||
Short: "Connect to a remote SSH server",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
|
||||
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
||||
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
err := util.InitLog(logLevel, util.LogConsole)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||
|
||||
if !util.IsAdmin() {
|
||||
cmd.Printf("error: you must have Administrator privileges to run this command\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sm := profilemanager.NewServiceManager(configPath)
|
||||
activeProf, err := sm.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
profPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile path: %v", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.ReadConfig(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read profile config: %v", err)
|
||||
}
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
sshCmd.AddCommand(sshSftpCmd)
|
||||
sshCmd.AddCommand(sshProxyCmd)
|
||||
sshCmd.AddCommand(sshDetectCmd)
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
|
||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
"\nYou can verify the connection by running:\n\n" +
|
||||
" netbird status\n\n")
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err = c.Close()
|
||||
if err != nil {
|
||||
return
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh [flags] [user@]host [command]",
|
||||
Short: "Connect to a NetBird peer via SSH",
|
||||
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
|
||||
|
||||
Port Forwarding:
|
||||
-L [bind_address:]port:host:hostport Local port forwarding
|
||||
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
|
||||
-R [bind_address:]port:host:hostport Remote port forwarding
|
||||
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
|
||||
|
||||
SSH Options:
|
||||
-p, --port int Remote SSH port (default 22)
|
||||
-u, --user string SSH username
|
||||
--login string SSH username (alias for --user)
|
||||
-t, --tty Force pseudo-terminal allocation
|
||||
--strict-host-key-checking Enable strict host key checking (default: true)
|
||||
-o, --known-hosts string Path to known_hosts file
|
||||
|
||||
Examples:
|
||||
netbird ssh peer-hostname
|
||||
netbird ssh root@peer-hostname
|
||||
netbird ssh --login root peer-hostname
|
||||
netbird ssh peer-hostname ls -la
|
||||
netbird ssh peer-hostname whoami
|
||||
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
|
||||
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
|
||||
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
||||
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
||||
DisableFlagParsing: true,
|
||||
Args: validateSSHArgsWithoutFlagParsing,
|
||||
RunE: sshFn,
|
||||
Aliases: []string{"ssh"},
|
||||
}
|
||||
|
||||
func sshFn(cmd *cobra.Command, args []string) error {
|
||||
for _, arg := range args {
|
||||
if arg == "-h" || arg == "--help" {
|
||||
return cmd.Help()
|
||||
}
|
||||
}
|
||||
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := runSSH(sshctx, host, cmd); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err = c.OpenTerminal()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
<-sshctx.Done()
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
|
||||
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
|
||||
func getEnvOrDefault(flagName, defaultValue string) string {
|
||||
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||
return envValue
|
||||
}
|
||||
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||
return envValue
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// resetSSHGlobals sets SSH globals to their default values
|
||||
func resetSSHGlobals() {
|
||||
port = sshserver.DefaultSSHPort
|
||||
username = ""
|
||||
host = ""
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
strictHostKeyChecking = true
|
||||
knownHostsFile = ""
|
||||
identityFile = ""
|
||||
}
|
||||
|
||||
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
|
||||
var localForwardFlags []string
|
||||
var remoteForwardFlags []string
|
||||
var filteredArgs []string
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
switch {
|
||||
case strings.HasPrefix(arg, "-L"):
|
||||
localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags)
|
||||
case strings.HasPrefix(arg, "-R"):
|
||||
remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags)
|
||||
default:
|
||||
filteredArgs = append(filteredArgs, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredArgs, localForwardFlags, remoteForwardFlags
|
||||
}
|
||||
|
||||
func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) {
|
||||
if arg == "-L" || arg == "-R" {
|
||||
if i+1 < len(args) {
|
||||
flags = append(flags, args[i+1])
|
||||
i++
|
||||
}
|
||||
} else if len(arg) > 2 {
|
||||
flags = append(flags, arg[2:])
|
||||
}
|
||||
return flags, i
|
||||
}
|
||||
|
||||
// extractGlobalFlags parses global flags that were passed before 'ssh' command
|
||||
func extractGlobalFlags(args []string) {
|
||||
sshPos := findSSHCommandPosition(args)
|
||||
if sshPos == -1 {
|
||||
return
|
||||
}
|
||||
|
||||
globalArgs := args[:sshPos]
|
||||
parseGlobalArgs(globalArgs)
|
||||
}
|
||||
|
||||
// findSSHCommandPosition locates the 'ssh' command in the argument list
|
||||
func findSSHCommandPosition(args []string) int {
|
||||
for i, arg := range args {
|
||||
if arg == "ssh" {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
const (
|
||||
configFlag = "config"
|
||||
logLevelFlag = "log-level"
|
||||
logFileFlag = "log-file"
|
||||
)
|
||||
|
||||
// parseGlobalArgs processes the global arguments and sets the corresponding variables
|
||||
func parseGlobalArgs(globalArgs []string) {
|
||||
flagHandlers := map[string]func(string){
|
||||
configFlag: func(value string) { configPath = value },
|
||||
logLevelFlag: func(value string) { logLevel = value },
|
||||
logFileFlag: func(value string) {
|
||||
if !slices.Contains(logFiles, value) {
|
||||
logFiles = append(logFiles, value)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
shortFlags := map[string]string{
|
||||
"c": configFlag,
|
||||
"l": logLevelFlag,
|
||||
}
|
||||
|
||||
for i := 0; i < len(globalArgs); i++ {
|
||||
arg := globalArgs[i]
|
||||
|
||||
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
|
||||
i = nextIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseFlag handles generic flag parsing for both long and short forms
|
||||
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
|
||||
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
|
||||
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||
return true, currentIndex
|
||||
}
|
||||
|
||||
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
|
||||
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||
return true, currentIndex + 1
|
||||
}
|
||||
|
||||
return false, currentIndex
|
||||
}
|
||||
|
||||
type parsedFlag struct {
|
||||
flagName string
|
||||
value string
|
||||
}
|
||||
|
||||
// parseEqualsFormat handles --flag=value and -f=value formats
|
||||
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||
if !strings.Contains(arg, "=") {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(arg, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(parts[0], "--") {
|
||||
flagName := strings.TrimPrefix(parts[0], "--")
|
||||
if _, exists := flagHandlers[flagName]; exists {
|
||||
return parsedFlag{flagName: flagName, value: parts[1]}, true
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
|
||||
shortFlag := strings.TrimPrefix(parts[0], "-")
|
||||
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||
if _, exists := flagHandlers[longFlag]; exists {
|
||||
return parsedFlag{flagName: longFlag, value: parts[1]}, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
// parseSpacedFormat handles --flag value and -f value formats
|
||||
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||
if currentIndex+1 >= len(args) {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(arg, "--") {
|
||||
flagName := strings.TrimPrefix(arg, "--")
|
||||
if _, exists := flagHandlers[flagName]; exists {
|
||||
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
|
||||
shortFlag := strings.TrimPrefix(arg, "-")
|
||||
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||
if _, exists := flagHandlers[longFlag]; exists {
|
||||
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
||||
// sshFlags contains all SSH-related flags and parameters
|
||||
type sshFlags struct {
|
||||
Port int
|
||||
Username string
|
||||
Login string
|
||||
RequestPTY bool
|
||||
StrictHostKeyChecking bool
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
SkipCachedToken bool
|
||||
ConfigPath string
|
||||
LogLevel string
|
||||
LocalForwards []string
|
||||
RemoteForwards []string
|
||||
Host string
|
||||
Command string
|
||||
}
|
||||
|
||||
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil)
|
||||
|
||||
flags := &sshFlags{}
|
||||
|
||||
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
|
||||
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
|
||||
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
|
||||
|
||||
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
|
||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
|
||||
|
||||
return fs, flags
|
||||
}
|
||||
|
||||
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
resetSSHGlobals()
|
||||
|
||||
if len(os.Args) > 2 {
|
||||
extractGlobalFlags(os.Args[1:])
|
||||
}
|
||||
|
||||
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
||||
|
||||
fs, flags := createSSHFlagSet()
|
||||
|
||||
if err := fs.Parse(filteredArgs); err != nil {
|
||||
if errors.Is(err, flag.ErrHelp) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
remaining := fs.Args()
|
||||
if len(remaining) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
port = flags.Port
|
||||
if flags.Username != "" {
|
||||
username = flags.Username
|
||||
} else if flags.Login != "" {
|
||||
username = flags.Login
|
||||
}
|
||||
|
||||
requestPTY = flags.RequestPTY
|
||||
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||
knownHostsFile = flags.KnownHostsFile
|
||||
identityFile = flags.IdentityFile
|
||||
skipCachedToken = flags.SkipCachedToken
|
||||
|
||||
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = flags.ConfigPath
|
||||
}
|
||||
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||
logLevel = flags.LogLevel
|
||||
}
|
||||
|
||||
localForwards = localForwardFlags
|
||||
remoteForwards = remoteForwardFlags
|
||||
|
||||
return parseHostnameAndCommand(remaining)
|
||||
}
|
||||
|
||||
func parseHostnameAndCommand(args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
arg := args[0]
|
||||
if strings.Contains(arg, "@") {
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return errors.New("invalid user@host format")
|
||||
}
|
||||
if username == "" {
|
||||
username = parts[0]
|
||||
}
|
||||
host = parts[1]
|
||||
} else {
|
||||
host = arg
|
||||
}
|
||||
|
||||
if username == "" {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
username = sudoUser
|
||||
} else if currentUser, err := user.Current(); err == nil {
|
||||
username = currentUser.Username
|
||||
} else {
|
||||
username = "root"
|
||||
}
|
||||
}
|
||||
|
||||
// Everything after hostname becomes the command
|
||||
if len(args) > 1 {
|
||||
command = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
target := fmt.Sprintf("%s:%d", addr, port)
|
||||
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||
KnownHostsFile: knownHostsFile,
|
||||
IdentityFile: identityFile,
|
||||
DaemonAddr: daemonAddr,
|
||||
SkipCachedToken: skipCachedToken,
|
||||
InsecureSkipVerify: !strictHostKeyChecking,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
||||
cmd.Printf("\nTroubleshooting steps:\n")
|
||||
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
|
||||
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
||||
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
|
||||
return fmt.Errorf("dial %s: %w", target, err)
|
||||
}
|
||||
|
||||
sshCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-sshCtx.Done()
|
||||
if err := c.Close(); err != nil {
|
||||
cmd.Printf("Error closing SSH connection: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
|
||||
return fmt.Errorf("start port forwarding: %w", err)
|
||||
}
|
||||
|
||||
if command != "" {
|
||||
return executeSSHCommand(sshCtx, c, command)
|
||||
}
|
||||
return openSSHTerminal(sshCtx, c)
|
||||
}
|
||||
|
||||
// executeSSHCommand executes a command over SSH.
|
||||
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
||||
var err error
|
||||
if requestPTY {
|
||||
err = c.ExecuteCommandWithPTY(ctx, command)
|
||||
} else {
|
||||
err = c.ExecuteCommandWithIO(ctx, command)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exitErr *ssh.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
os.Exit(exitErr.ExitStatus())
|
||||
}
|
||||
|
||||
var exitMissingErr *ssh.ExitMissingError
|
||||
if errors.As(err, &exitMissingErr) {
|
||||
log.Debugf("Remote command exited without exit status: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// openSSHTerminal opens an interactive SSH terminal.
|
||||
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
|
||||
if err := c.OpenTerminal(ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exitMissingErr *ssh.ExitMissingError
|
||||
if errors.As(err, &exitMissingErr) {
|
||||
log.Debugf("Remote terminal exited without exit status: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("open terminal: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startPortForwarding starts local and remote port forwarding based on command line flags
|
||||
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
|
||||
for _, forward := range localForwards {
|
||||
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
|
||||
return fmt.Errorf("local port forward %s: %w", forward, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, forward := range remoteForwards {
|
||||
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
|
||||
return fmt.Errorf("remote port forward %s: %w", forward, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAndStartLocalForward parses and starts a local port forward (-L)
|
||||
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.Printf("Local port forward error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
|
||||
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.Printf("Remote port forward error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||
// Support formats:
|
||||
// port:host:hostport -> localhost:port -> host:hostport
|
||||
// host:port:host:hostport -> host:port -> host:hostport
|
||||
// [host]:port:host:hostport -> [host]:port -> host:hostport
|
||||
// port:unix_socket_path -> localhost:port -> unix_socket_path
|
||||
// host:port:unix_socket_path -> host:port -> unix_socket_path
|
||||
|
||||
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
|
||||
return parseIPv6ForwardSpec(spec)
|
||||
}
|
||||
|
||||
parts := strings.Split(spec, ":")
|
||||
if len(parts) < 2 {
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
|
||||
}
|
||||
|
||||
switch len(parts) {
|
||||
case 2:
|
||||
return parseTwoPartForwardSpec(parts, spec)
|
||||
case 3:
|
||||
return parseThreePartForwardSpec(parts)
|
||||
case 4:
|
||||
return parseFourPartForwardSpec(parts)
|
||||
default:
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
|
||||
}
|
||||
}
|
||||
|
||||
// parseTwoPartForwardSpec handles "port:unix_socket" format.
|
||||
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
|
||||
if isUnixSocket(parts[1]) {
|
||||
localAddr := "localhost:" + parts[0]
|
||||
remoteAddr := parts[1]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
|
||||
}
|
||||
|
||||
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
|
||||
func parseThreePartForwardSpec(parts []string) (string, string, error) {
|
||||
if isUnixSocket(parts[2]) {
|
||||
localHost := normalizeLocalHost(parts[0])
|
||||
localAddr := localHost + ":" + parts[1]
|
||||
remoteAddr := parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
localAddr := "localhost:" + parts[0]
|
||||
remoteAddr := parts[1] + ":" + parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
|
||||
func parseFourPartForwardSpec(parts []string) (string, string, error) {
|
||||
localHost := normalizeLocalHost(parts[0])
|
||||
localAddr := localHost + ":" + parts[1]
|
||||
remoteAddr := parts[2] + ":" + parts[3]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
|
||||
func parseIPv6ForwardSpec(spec string) (string, string, error) {
|
||||
idx := strings.Index(spec, "]:")
|
||||
if idx == -1 {
|
||||
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
|
||||
}
|
||||
|
||||
ipv6Host := spec[:idx+1]
|
||||
remaining := spec[idx+2:]
|
||||
|
||||
parts := strings.Split(remaining, ":")
|
||||
if len(parts) != 3 {
|
||||
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
|
||||
}
|
||||
|
||||
localAddr := ipv6Host + ":" + parts[0]
|
||||
remoteAddr := parts[1] + ":" + parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// isUnixSocket checks if a path is a Unix socket path.
|
||||
func isUnixSocket(path string) bool {
|
||||
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||
}
|
||||
|
||||
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||
func normalizeLocalHost(host string) string {
|
||||
if host == "*" {
|
||||
return "0.0.0.0"
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
var sshProxyCmd = &cobra.Command{
|
||||
Use: "proxy <host> <port>",
|
||||
Short: "Internal SSH proxy for native SSH client integration",
|
||||
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: sshProxyFn,
|
||||
}
|
||||
|
||||
func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
host := args[0]
|
||||
portStr := args[1]
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SSH proxy: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.Close(); err != nil {
|
||||
log.Debugf("close SSH proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||
return fmt.Errorf("SSH proxy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var sshDetectCmd = &cobra.Command{
|
||||
Use: "detect <host> <port>",
|
||||
Short: "Detect if a host is running NetBird SSH",
|
||||
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: sshDetectFn,
|
||||
}
|
||||
|
||||
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
host := args[0]
|
||||
portStr := args[1]
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
|
||||
if err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
os.Exit(serverType.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
74
client/cmd/ssh_exec_unix.go
Normal file
74
client/cmd/ssh_exec_unix.go
Normal file
@@ -0,0 +1,74 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sshExecUID uint32
|
||||
sshExecGID uint32
|
||||
sshExecGroups []uint
|
||||
sshExecWorkingDir string
|
||||
sshExecShell string
|
||||
sshExecCommand string
|
||||
sshExecPTY bool
|
||||
)
|
||||
|
||||
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
|
||||
var sshExecCmd = &cobra.Command{
|
||||
Use: "exec",
|
||||
Short: "Internal SSH execution with privilege dropping (hidden)",
|
||||
Hidden: true,
|
||||
RunE: runSSHExec,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
|
||||
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
|
||||
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
|
||||
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
|
||||
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
|
||||
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
|
||||
|
||||
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
sshCmd.AddCommand(sshExecCmd)
|
||||
}
|
||||
|
||||
// runSSHExec handles the SSH exec subcommand execution.
|
||||
func runSSHExec(cmd *cobra.Command, _ []string) error {
|
||||
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||
|
||||
var groups []uint32
|
||||
for _, groupInt := range sshExecGroups {
|
||||
groups = append(groups, uint32(groupInt))
|
||||
}
|
||||
|
||||
config := sshserver.ExecutorConfig{
|
||||
UID: sshExecUID,
|
||||
GID: sshExecGID,
|
||||
Groups: groups,
|
||||
WorkingDir: sshExecWorkingDir,
|
||||
Shell: sshExecShell,
|
||||
Command: sshExecCommand,
|
||||
PTY: sshExecPTY,
|
||||
}
|
||||
|
||||
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
|
||||
return nil
|
||||
}
|
||||
94
client/cmd/ssh_sftp_unix.go
Normal file
94
client/cmd/ssh_sftp_unix.go
Normal file
@@ -0,0 +1,94 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpUID uint32
|
||||
sftpGID uint32
|
||||
sftpGroupsInt []uint
|
||||
sftpWorkingDir string
|
||||
)
|
||||
|
||||
var sshSftpCmd = &cobra.Command{
|
||||
Use: "sftp",
|
||||
Short: "SFTP server with privilege dropping (internal use)",
|
||||
Hidden: true,
|
||||
RunE: sftpMain,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
|
||||
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
|
||||
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||
}
|
||||
|
||||
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||
|
||||
var groups []uint32
|
||||
for _, groupInt := range sftpGroupsInt {
|
||||
groups = append(groups, uint32(groupInt))
|
||||
}
|
||||
|
||||
config := sshserver.ExecutorConfig{
|
||||
UID: sftpUID,
|
||||
GID: sftpGID,
|
||||
Groups: groups,
|
||||
WorkingDir: sftpWorkingDir,
|
||||
Shell: "",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
|
||||
|
||||
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
|
||||
cmd.PrintErrf("privilege drop failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodePrivilegeDropFail)
|
||||
}
|
||||
|
||||
if config.WorkingDir != "" {
|
||||
if err := os.Chdir(config.WorkingDir); err != nil {
|
||||
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
sftpServer, err := sftp.NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
Reader: os.Stdin,
|
||||
WriteCloser: os.Stdout,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
log.Tracef("starting SFTP server with dropped privileges")
|
||||
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||
}
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||
}
|
||||
os.Exit(sshserver.ExitCodeSuccess)
|
||||
return nil
|
||||
}
|
||||
94
client/cmd/ssh_sftp_windows.go
Normal file
94
client/cmd/ssh_sftp_windows.go
Normal file
@@ -0,0 +1,94 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/user"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpWorkingDir string
|
||||
windowsUsername string
|
||||
windowsDomain string
|
||||
)
|
||||
|
||||
var sshSftpCmd = &cobra.Command{
|
||||
Use: "sftp",
|
||||
Short: "SFTP server with user switching for Windows (internal use)",
|
||||
Hidden: true,
|
||||
RunE: sftpMain,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
|
||||
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
|
||||
}
|
||||
|
||||
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||
return sftpMainDirect(cmd)
|
||||
}
|
||||
|
||||
func sftpMainDirect(cmd *cobra.Command) error {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
cmd.PrintErrf("failed to get current user: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
|
||||
if windowsUsername != "" {
|
||||
expectedUsername := windowsUsername
|
||||
if windowsDomain != "" {
|
||||
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
||||
}
|
||||
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
|
||||
cmd.PrintErrf("user switching failed\n")
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
|
||||
|
||||
if sftpWorkingDir != "" {
|
||||
if err := os.Chdir(sftpWorkingDir); err != nil {
|
||||
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
sftpServer, err := sftp.NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
Reader: os.Stdin,
|
||||
WriteCloser: os.Stdout,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
log.Debugf("starting SFTP server")
|
||||
exitCode := sshserver.ExitCodeSuccess
|
||||
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||
exitCode = sshserver.ExitCodeShellExecFail
|
||||
}
|
||||
|
||||
if err := sftpServer.Close(); err != nil {
|
||||
log.Debugf("SFTP server close error: %v", err)
|
||||
}
|
||||
|
||||
os.Exit(exitCode)
|
||||
return nil
|
||||
}
|
||||
717
client/cmd/ssh_test.go
Normal file
717
client/cmd/ssh_test.go
Normal file
@@ -0,0 +1,717 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedUser string
|
||||
expectedPort int
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "basic host",
|
||||
args: []string{"hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "",
|
||||
},
|
||||
{
|
||||
name: "user@host format",
|
||||
args: []string{"user@hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "user",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "",
|
||||
},
|
||||
{
|
||||
name: "host with command",
|
||||
args: []string{"hostname", "echo", "hello"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "echo hello",
|
||||
},
|
||||
{
|
||||
name: "command with flags should be preserved",
|
||||
args: []string{"hostname", "ls", "-la", "/tmp"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "ls -la /tmp",
|
||||
},
|
||||
{
|
||||
name: "double dash separator",
|
||||
args: []string{"hostname", "--", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "-- ls -la",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
// Mock command for testing
|
||||
cmd := sshCmd
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
if tt.expectedUser != "" {
|
||||
assert.Equal(t, tt.expectedUser, username, "username mismatch")
|
||||
}
|
||||
assert.Equal(t, tt.expectedPort, port, "port mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
|
||||
// Test that SSH flags don't conflict with command flags
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls with -la flags",
|
||||
args: []string{"hostname", "ls", "-la"},
|
||||
expectedCmd: "ls -la",
|
||||
description: "ls flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "grep with -r flag",
|
||||
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
|
||||
expectedCmd: "grep -r pattern /path",
|
||||
description: "grep flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "ps with aux flags",
|
||||
args: []string{"hostname", "ps", "aux"},
|
||||
expectedCmd: "ps aux",
|
||||
description: "ps flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "command with double dash",
|
||||
args: []string{"hostname", "--", "ls", "-la"},
|
||||
expectedCmd: "-- ls -la",
|
||||
description: "double dash should be preserved in command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
|
||||
// Test that commands with arguments should execute the command and exit,
|
||||
// not drop to an interactive shell
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
shouldExit bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls command should execute and exit",
|
||||
args: []string{"hostname", "ls"},
|
||||
expectedCmd: "ls",
|
||||
shouldExit: true,
|
||||
description: "ls command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "ls with flags should execute and exit",
|
||||
args: []string{"hostname", "ls", "-la"},
|
||||
expectedCmd: "ls -la",
|
||||
shouldExit: true,
|
||||
description: "ls with flags should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "pwd command should execute and exit",
|
||||
args: []string{"hostname", "pwd"},
|
||||
expectedCmd: "pwd",
|
||||
shouldExit: true,
|
||||
description: "pwd command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "echo command should execute and exit",
|
||||
args: []string{"hostname", "echo", "hello"},
|
||||
expectedCmd: "echo hello",
|
||||
shouldExit: true,
|
||||
description: "echo command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "no command should open shell",
|
||||
args: []string{"hostname"},
|
||||
expectedCmd: "",
|
||||
shouldExit: false,
|
||||
description: "no command should open interactive shell",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
|
||||
// When command is present, it should execute the command and exit
|
||||
// When command is empty, it should open interactive shell
|
||||
hasCommand := command != ""
|
||||
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_FlagHandling(t *testing.T) {
|
||||
// Test that flags after hostname are not parsed by netbird but passed to SSH command
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls with -la flag should not be parsed by netbird",
|
||||
args: []string{"debian2", "ls", "-la"},
|
||||
expectedHost: "debian2",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
|
||||
},
|
||||
{
|
||||
name: "command with netbird-like flags should be passed through",
|
||||
args: []string{"hostname", "echo", "--help"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "echo --help",
|
||||
expectError: false,
|
||||
description: "--help should be passed to echo, not parsed by netbird",
|
||||
},
|
||||
{
|
||||
name: "command with -p flag should not conflict with SSH port flag",
|
||||
args: []string{"hostname", "ps", "-p", "1234"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ps -p 1234",
|
||||
expectError: false,
|
||||
description: "ps -p should be passed to ps command, not parsed as port",
|
||||
},
|
||||
{
|
||||
name: "tar with flags should be passed through",
|
||||
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "tar -czf backup.tar.gz /home",
|
||||
expectError: false,
|
||||
description: "tar flags should be passed to tar command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
|
||||
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
|
||||
// should not parse -la as netbird flags but pass them to the SSH command
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "original issue: ls -la should be preserved",
|
||||
args: []string{"debian2", "ls", "-la"},
|
||||
expectedHost: "debian2",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "The original failing case should now work",
|
||||
},
|
||||
{
|
||||
name: "ls -l should be preserved",
|
||||
args: []string{"hostname", "ls", "-l"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ls -l",
|
||||
expectError: false,
|
||||
description: "Single letter flags should be preserved",
|
||||
},
|
||||
{
|
||||
name: "SSH port flag should work",
|
||||
args: []string{"-p", "2222", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "SSH -p flag should be parsed, command flags preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
|
||||
// Check port for the test case with -p flag
|
||||
if len(tt.args) > 0 && tt.args[0] == "-p" {
|
||||
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedLocal []string
|
||||
expectedRemote []string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "local port forwarding -L",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Single -L flag should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "remote port forwarding -R",
|
||||
args: []string{"-R", "8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80"},
|
||||
expectError: false,
|
||||
description: "Single -R flag should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "multiple local port forwards",
|
||||
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Multiple -L flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "multiple remote port forwards",
|
||||
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||
expectError: false,
|
||||
description: "Multiple -R flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "mixed local and remote forwards",
|
||||
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{"9090:localhost:443"},
|
||||
expectError: false,
|
||||
description: "Mixed -L and -R flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "port forwarding with bind address",
|
||||
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Port forwarding with bind address should work",
|
||||
},
|
||||
{
|
||||
name: "port forwarding with command",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Port forwarding with command should work",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(tt.expectedLocal) == 0 {
|
||||
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||
}
|
||||
if len(tt.expectedRemote) == 0 {
|
||||
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortForward(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec string
|
||||
expectedLocal string
|
||||
expectedRemote string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "simple port forward",
|
||||
spec: "8080:localhost:80",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Simple port:host:port format should work",
|
||||
},
|
||||
{
|
||||
name: "port forward with bind address",
|
||||
spec: "127.0.0.1:8080:localhost:80",
|
||||
expectedLocal: "127.0.0.1:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "bind_address:port:host:port format should work",
|
||||
},
|
||||
{
|
||||
name: "port forward to different host",
|
||||
spec: "8080:example.com:443",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "example.com:443",
|
||||
expectError: false,
|
||||
description: "Forwarding to different host should work",
|
||||
},
|
||||
{
|
||||
name: "port forward with IPv6 (needs bracket support)",
|
||||
spec: "::1:8080:localhost:80",
|
||||
expectError: true,
|
||||
description: "IPv6 without brackets fails as expected (feature to implement)",
|
||||
},
|
||||
{
|
||||
name: "invalid format - too few parts",
|
||||
spec: "8080:localhost",
|
||||
expectError: true,
|
||||
description: "Invalid format with too few parts should fail",
|
||||
},
|
||||
{
|
||||
name: "invalid format - too many parts",
|
||||
spec: "127.0.0.1:8080:localhost:80:extra",
|
||||
expectError: true,
|
||||
description: "Invalid format with too many parts should fail",
|
||||
},
|
||||
{
|
||||
name: "empty spec",
|
||||
spec: "",
|
||||
expectError: true,
|
||||
description: "Empty spec should fail",
|
||||
},
|
||||
{
|
||||
name: "unix socket local forward",
|
||||
spec: "8080:/tmp/socket",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "/tmp/socket",
|
||||
expectError: false,
|
||||
description: "Unix socket forwarding should work",
|
||||
},
|
||||
{
|
||||
name: "unix socket with bind address",
|
||||
spec: "127.0.0.1:8080:/tmp/socket",
|
||||
expectedLocal: "127.0.0.1:8080",
|
||||
expectedRemote: "/tmp/socket",
|
||||
expectError: false,
|
||||
description: "Unix socket with bind address should work",
|
||||
},
|
||||
{
|
||||
name: "wildcard bind all interfaces",
|
||||
spec: "*:8080:localhost:80",
|
||||
expectedLocal: "0.0.0.0:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||
},
|
||||
{
|
||||
name: "wildcard for port only",
|
||||
spec: "8080:*:80",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "*:80",
|
||||
expectError: false,
|
||||
description: "Wildcard in remote host should be preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, tt.description)
|
||||
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
|
||||
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
|
||||
// Integration test for port forwarding with the actual SSH command implementation
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedLocal []string
|
||||
expectedRemote []string
|
||||
expectedCmd string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "local forward with command",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectedCmd: "echo test",
|
||||
description: "Local forwarding should work with commands",
|
||||
},
|
||||
{
|
||||
name: "remote forward with command",
|
||||
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80"},
|
||||
expectedCmd: "ls -la",
|
||||
description: "Remote forwarding should work with commands",
|
||||
},
|
||||
{
|
||||
name: "multiple forwards with user and command",
|
||||
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{"9090:localhost:443"},
|
||||
expectedCmd: "ps aux",
|
||||
description: "Complex case with multiple forwards, user, and command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(tt.expectedLocal) == 0 {
|
||||
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||
}
|
||||
if len(tt.expectedRemote) == 0 {
|
||||
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||
}
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_ParameterIsolation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
}{
|
||||
{
|
||||
name: "cmd flag passed as command",
|
||||
args: []string{"hostname", "--cmd", "echo test"},
|
||||
expectedCmd: "--cmd echo test",
|
||||
},
|
||||
{
|
||||
name: "uid flag passed as command",
|
||||
args: []string{"hostname", "--uid", "1000"},
|
||||
expectedCmd: "--uid 1000",
|
||||
},
|
||||
{
|
||||
name: "shell flag passed as command",
|
||||
args: []string{"hostname", "--shell", "/bin/bash"},
|
||||
expectedCmd: "--shell /bin/bash",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "hostname", host)
|
||||
assert.Equal(t, tt.expectedCmd, command)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
|
||||
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "invalid long flag before hostname",
|
||||
args: []string{"--invalid-flag", "hostname"},
|
||||
description: "Invalid flag should return parse error, not treat flag as hostname",
|
||||
},
|
||||
{
|
||||
name: "invalid short flag before hostname",
|
||||
args: []string{"-x", "hostname"},
|
||||
description: "Invalid short flag should return parse error",
|
||||
},
|
||||
{
|
||||
name: "invalid flag with value before hostname",
|
||||
args: []string{"--invalid-option=value", "hostname"},
|
||||
description: "Invalid flag with value should return parse error",
|
||||
},
|
||||
{
|
||||
name: "typo in known flag",
|
||||
args: []string{"--por", "2222", "hostname"},
|
||||
description: "Typo in flag name should return parse error (not silently ignored)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||
|
||||
// Should return an error for invalid flags
|
||||
assert.Error(t, err, tt.description)
|
||||
|
||||
// Should not have set host to the invalid flag
|
||||
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
case yamlFlag:
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -13,6 +13,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
@@ -84,7 +89,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
jobManager := job.NewJobManager(nil, store)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
@@ -110,13 +115,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||
@@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
loginRequest.ProfileName = &activeProf.Name
|
||||
loginRequest.Username = &username
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
@@ -348,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
req.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
log.Errorf("parse interface name: %v", err)
|
||||
@@ -432,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
@@ -532,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
loginRequest.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
}
|
||||
|
||||
@@ -18,12 +18,16 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
var (
|
||||
ErrClientAlreadyStarted = errors.New("client already started")
|
||||
ErrClientNotStarted = errors.New("client not started")
|
||||
ErrEngineNotStarted = errors.New("engine not started")
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
@@ -169,6 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
@@ -176,7 +181,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
run := make(chan struct{})
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
if err := client.Run(run, ""); err != nil {
|
||||
clientErr <- err
|
||||
}
|
||||
}()
|
||||
@@ -238,17 +243,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
@@ -259,6 +256,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// DialContext dials a network address in the netbird network with context
|
||||
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return c.Dial(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
@@ -314,18 +316,47 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
storedKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||
if !found {
|
||||
return sshcommon.ErrPeerNotFound
|
||||
}
|
||||
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// getEngine safely retrieves the engine from the client with proper locking.
|
||||
// Returns ErrClientNotStarted if the client is not started.
|
||||
// Returns ErrEngineNotStarted if the engine is not available.
|
||||
func (c *Client) getEngine() (*internal.Engine, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, netip.Addr{}, errors.New("client not started")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if connect == nil {
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, netip.Addr{}, errors.New("engine not started")
|
||||
return nil, ErrEngineNotStarted
|
||||
}
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, err
|
||||
}
|
||||
|
||||
addr, err := engine.Address()
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nadoo/ipset"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -40,19 +41,13 @@ type aclManager struct {
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||
m := &aclManager{
|
||||
return &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
}
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
@@ -98,8 +93,8 @@ func (m *aclManager) AddPeerFiltering(
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
@@ -113,14 +108,18 @@ func (m *aclManager) AddPeerFiltering(
|
||||
}}, nil
|
||||
}
|
||||
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := ipset.Create(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||
if err := m.createIPSet(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("create ipset: %w", err)
|
||||
}
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
ipList := newIpList(ip.String())
|
||||
@@ -172,11 +171,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
shouldDestroyIpset := false
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||
ip := net.ParseIP(r.ip)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("parse IP %s", r.ip)
|
||||
}
|
||||
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
||||
return fmt.Errorf("delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(ipsetList.ips, r.ip)
|
||||
}
|
||||
@@ -190,10 +194,7 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||
|
||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||
log.Errorf("delete empty ipset: %v", err)
|
||||
}
|
||||
shouldDestroyIpset = true
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||
@@ -206,6 +207,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDestroyIpset {
|
||||
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy empty ipset: %v", err)
|
||||
} else {
|
||||
log.Errorf("destroy empty ipset: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
@@ -264,11 +275,19 @@ func (m *aclManager) cleanChains() error {
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := ipset.Destroy(ipsetName); err != nil {
|
||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||
if err := m.destroyIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
m.ipsetStore.deleteIpset(ipsetName)
|
||||
}
|
||||
@@ -368,8 +387,8 @@ func (m *aclManager) updateState() {
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
if ip.IsUnspecified() {
|
||||
matchByIP = false
|
||||
}
|
||||
|
||||
@@ -416,3 +435,61 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
||||
return ipsetName + actionSuffix
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
}
|
||||
|
||||
if err := ipset.Del(name, entry); err != nil {
|
||||
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) flushIPSet(name string) error {
|
||||
return ipset.Flush(name)
|
||||
}
|
||||
|
||||
func (m *aclManager) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/nadoo/ipset"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
@@ -107,10 +107,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
||||
},
|
||||
)
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
@@ -232,12 +228,12 @@ func (r *router) findSets(rule []string) []string {
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||
if err := r.createIPSet(setName); err != nil {
|
||||
return fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
@@ -246,7 +242,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string) error {
|
||||
if err := ipset.Destroy(setName); err != nil {
|
||||
if err := r.destroyIPSet(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
@@ -915,8 +911,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
@@ -993,3 +989,37 @@ func applyPort(flag string, port *firewall.Port) []string {
|
||||
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
|
||||
func (r *router) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
ip := addr.AsSlice()
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: uint8(prefix.Bits()),
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
|
||||
@@ -35,6 +35,12 @@ const (
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
const (
|
||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
@@ -59,12 +65,6 @@ const (
|
||||
|
||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -1114,3 +1115,138 @@ func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dst
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestShouldForward(t *testing.T) {
|
||||
// Set up test addresses
|
||||
wgIP := netip.MustParseAddr("100.10.0.1")
|
||||
otherIP := netip.MustParseAddr("100.10.0.2")
|
||||
|
||||
// Create test manager with mock interface
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
// Set the mock to return our test WG IP
|
||||
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Helper to create decoder with TCP packet
|
||||
createTCPDecoder := func(dstPort uint16) *decoder {
|
||||
ipv4 := &layers.IPv4{
|
||||
Version: 4,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: net.ParseIP("192.168.1.100"),
|
||||
DstIP: wgIP.AsSlice(),
|
||||
}
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: 54321,
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
|
||||
err := tcp.SetNetworkLayerForChecksum(ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
|
||||
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
localForwarding bool
|
||||
netstack bool
|
||||
dstIP netip.Addr
|
||||
serviceRegistered bool
|
||||
servicePort uint16
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no local forwarding",
|
||||
localForwarding: false,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
expected: false,
|
||||
description: "should never forward when local forwarding disabled",
|
||||
},
|
||||
{
|
||||
name: "traffic to other local interface",
|
||||
localForwarding: true,
|
||||
netstack: false,
|
||||
dstIP: otherIP,
|
||||
expected: true,
|
||||
description: "should forward traffic to our other local interfaces (not NetBird IP)",
|
||||
},
|
||||
{
|
||||
name: "traffic to NetBird IP, no netstack",
|
||||
localForwarding: true,
|
||||
netstack: false,
|
||||
dstIP: wgIP,
|
||||
expected: false,
|
||||
description: "should send to netstack listeners (final return false path)",
|
||||
},
|
||||
{
|
||||
name: "traffic to our IP, netstack mode, no service",
|
||||
localForwarding: true,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
expected: true,
|
||||
description: "should forward when in netstack mode with no matching service",
|
||||
},
|
||||
{
|
||||
name: "traffic to our IP, netstack mode, with service",
|
||||
localForwarding: true,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
serviceRegistered: true,
|
||||
servicePort: 22,
|
||||
expected: false,
|
||||
description: "should send to netstack listeners when service is registered",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Configure manager
|
||||
manager.localForwarding = tt.localForwarding
|
||||
manager.netstack = tt.netstack
|
||||
|
||||
// Register service if needed
|
||||
if tt.serviceRegistered {
|
||||
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||
}
|
||||
|
||||
// Create decoder for the test
|
||||
decoder := createTCPDecoder(tt.servicePort)
|
||||
if !tt.serviceRegistered {
|
||||
decoder = createTCPDecoder(8080) // Use non-registered port
|
||||
}
|
||||
|
||||
// Test the method
|
||||
result := manager.shouldForward(decoder, tt.dstIP)
|
||||
require.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// TestPortDNATBasic tests basic port DNAT functionality
|
||||
func TestPortDNATBasic(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer IPs
|
||||
peerA := netip.MustParseAddr("100.10.0.50")
|
||||
peerB := netip.MustParseAddr("100.10.0.51")
|
||||
|
||||
// Add SSH port redirection rule for peer B (the target)
|
||||
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
|
||||
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||
d := parsePacket(t, packetAtoB)
|
||||
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB)
|
||||
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
|
||||
|
||||
// Verify port was translated to 22022
|
||||
d = parsePacket(t, packetAtoB)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
|
||||
|
||||
// Scenario: Return traffic from Peer B to Peer A should NOT be translated
|
||||
// (prevents double NAT - original port stored in conntrack)
|
||||
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
|
||||
d2 := parsePacket(t, returnPacket)
|
||||
translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA)
|
||||
require.False(t, translatedReturn, "Return traffic from same IP should not be translated")
|
||||
}
|
||||
|
||||
// TestPortDNATMultipleRules tests multiple port DNAT rules
|
||||
func TestPortDNATMultipleRules(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer IPs
|
||||
peerA := netip.MustParseAddr("100.10.0.50")
|
||||
peerB := netip.MustParseAddr("100.10.0.51")
|
||||
|
||||
// Add SSH port redirection rules for both peers
|
||||
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test traffic to peer B gets translated
|
||||
packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||
d1 := parsePacket(t, packetToB)
|
||||
translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB)
|
||||
require.True(t, translatedToB, "Traffic to peer B should be translated")
|
||||
d1 = parsePacket(t, packetToB)
|
||||
require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022")
|
||||
|
||||
// Test traffic to peer A gets translated
|
||||
packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
|
||||
d2 := parsePacket(t, packetToA)
|
||||
translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA)
|
||||
require.True(t, translatedToA, "Traffic to peer A should be translated")
|
||||
d2 = parsePacket(t, packetToA)
|
||||
require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022")
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
@@ -20,9 +18,6 @@ import (
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
||||
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
||||
|
||||
// Backoff returns a backoff configuration for gRPC calls
|
||||
func Backoff(ctx context.Context) backoff.BackOff {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
@@ -31,26 +26,6 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
|
||||
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
||||
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
||||
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||
conn.Connect()
|
||||
|
||||
state := conn.GetState()
|
||||
for state != connectivity.Ready && state != connectivity.Shutdown {
|
||||
if !conn.WaitForStateChange(ctx, state) {
|
||||
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
||||
}
|
||||
state = conn.GetState()
|
||||
}
|
||||
|
||||
if state == connectivity.Shutdown {
|
||||
return ErrConnectionShutdown
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||
@@ -68,25 +43,22 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
}))
|
||||
}
|
||||
|
||||
conn, err := grpc.NewClient(
|
||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
connCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
WithCustomDialer(tlsEnabled, component),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new client: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := waitForConnectionReady(ctx, conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("dial context: %w", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -9,13 +10,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// keep darwin compatibility
|
||||
@@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
addr := "100.64.0.1/8"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
||||
func Test_CreateInterface(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
||||
wgIP := "10.99.99.1/32"
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -166,7 +167,7 @@ func Test_Close(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||
wgIP := "10.99.99.2/32"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||
wgIP := "10.99.99.2/32"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
||||
wgIP := "10.99.99.5/30"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
func Test_UpdatePeer(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
wgIP := "10.99.99.9/30"
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
func Test_RemovePeer(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
wgIP := "10.99.99.13/30"
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
peer2wgPort := 33200
|
||||
|
||||
keepAlive := 1 * time.Second
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
|
||||
newNet, err = stdnet.NewNet()
|
||||
newNet, err = stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package udpmux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -12,8 +13,9 @@ import (
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun/v3"
|
||||
"github.com/pion/transport/v3"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
||||
if len(networks) > 0 {
|
||||
if m.params.Net == nil {
|
||||
var err error
|
||||
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
||||
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
|
||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
@@ -83,22 +82,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
rules := networkMap.FirewallRules
|
||||
|
||||
enableSSH := networkMap.PeerConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||
|
||||
// If SSH enabled, add default firewall rule which accepts connection to any peer
|
||||
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
|
||||
if enableSSH {
|
||||
rules = append(rules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
||||
})
|
||||
}
|
||||
|
||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||
// we have old version of management without rules handling, we should allow all traffic
|
||||
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
||||
|
||||
@@ -272,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
PeerConfig: &mgmProto.PeerConfig{
|
||||
SshConfig: &mgmProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
},
|
||||
},
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
},
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
expectedRules := 3
|
||||
if fw.IsStateful() {
|
||||
expectedRules = 3 // 2 inbound rules + SSH rule
|
||||
}
|
||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||
}
|
||||
|
||||
@@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
|
||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
if d.providerConfig.LoginHint != "" {
|
||||
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
|
||||
if deviceCode.VerificationURI != "" {
|
||||
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
|
||||
}
|
||||
}
|
||||
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
func appendLoginHint(uri, loginHint string) string {
|
||||
if uri == "" || loginHint == "" {
|
||||
return uri
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
log.Debugf("failed to parse verification URI for login_hint: %v", err)
|
||||
return uri
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
query.Set("login_hint", loginHint)
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
|
||||
return parsedURL.String()
|
||||
}
|
||||
|
||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", d.providerConfig.ClientID)
|
||||
|
||||
@@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
}
|
||||
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
||||
log.Debug("falling back to device code flow")
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
}
|
||||
return pkceFlow, nil
|
||||
}
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
@@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
@@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||
}
|
||||
}
|
||||
if p.providerConfig.LoginHint != "" {
|
||||
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
|
||||
}
|
||||
|
||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||
|
||||
@@ -189,17 +192,20 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token,
|
||||
|
||||
if authError := query.Get(queryError); authError != "" {
|
||||
authErrorDesc := query.Get(queryErrorDesc)
|
||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||
if authErrorDesc != "" {
|
||||
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
|
||||
}
|
||||
return nil, fmt.Errorf("authentication failed: %s", authError)
|
||||
}
|
||||
|
||||
// Prevent timing attacks on the state
|
||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
return nil, fmt.Errorf("authentication failed: Invalid state")
|
||||
}
|
||||
|
||||
code := query.Get(queryCode)
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("missing code")
|
||||
return nil, fmt.Errorf("authentication failed: missing code")
|
||||
}
|
||||
|
||||
return p.oAuthConfig.Exchange(
|
||||
@@ -228,7 +234,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
|
||||
}
|
||||
|
||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err)
|
||||
}
|
||||
|
||||
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
||||
|
||||
@@ -52,7 +52,6 @@ func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
@@ -63,8 +62,8 @@ func NewConnectClient(
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}) error {
|
||||
return c.run(MobileDependency{}, runningChan)
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
}
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
@@ -83,7 +82,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
@@ -101,10 +100,10 @@ func (c *ConnectClient) RunOniOS(
|
||||
DnsManager: dnsManager,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -247,7 +246,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -271,7 +270,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, c.config)
|
||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
@@ -410,26 +409,31 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
|
||||
nm := false
|
||||
if config.NetworkMonitor != nil {
|
||||
nm = *config.NetworkMonitor
|
||||
}
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: nm,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: nm,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: config.DisableSSHAuth,
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
|
||||
DisableClientRoutes: config.DisableClientRoutes,
|
||||
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
||||
@@ -440,7 +444,10 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
LogPath: logPath,
|
||||
|
||||
ProfileConfig: config,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
@@ -515,6 +522,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,8 +27,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
@@ -44,6 +46,8 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
@@ -184,6 +188,20 @@ The ip_rules.txt file contains detailed IP routing rule information:
|
||||
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
|
||||
|
||||
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
|
||||
|
||||
DNS Configuration
|
||||
The debug bundle includes platform-specific DNS configuration files:
|
||||
|
||||
resolv.conf (Unix systems):
|
||||
- Contains DNS resolver configuration from /etc/resolv.conf
|
||||
- Includes nameserver entries, search domains, and resolver options
|
||||
- All IP addresses and domain names are anonymized following the same rules as other files
|
||||
|
||||
scutil_dns.txt (macOS only):
|
||||
- Contains detailed DNS configuration from scutil --dns
|
||||
- Shows DNS configuration for all network interfaces
|
||||
- Includes search domains, nameservers, and DNS resolver settings
|
||||
- All IP addresses and domain names are anonymized
|
||||
`
|
||||
|
||||
const (
|
||||
@@ -202,10 +220,9 @@ type BundleGenerator struct {
|
||||
internalConfig *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logFile string
|
||||
logPath string
|
||||
|
||||
anonymize bool
|
||||
clientStatus string
|
||||
includeSystemInfo bool
|
||||
logFileCount uint32
|
||||
|
||||
@@ -214,7 +231,6 @@ type BundleGenerator struct {
|
||||
|
||||
type BundleConfig struct {
|
||||
Anonymize bool
|
||||
ClientStatus string
|
||||
IncludeSystemInfo bool
|
||||
LogFileCount uint32
|
||||
}
|
||||
@@ -223,7 +239,7 @@ type GeneratorDependencies struct {
|
||||
InternalConfig *profilemanager.Config
|
||||
StatusRecorder *peer.Status
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogFile string
|
||||
LogPath string
|
||||
}
|
||||
|
||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||
@@ -239,10 +255,9 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
internalConfig: deps.InternalConfig,
|
||||
statusRecorder: deps.StatusRecorder,
|
||||
syncResponse: deps.SyncResponse,
|
||||
logFile: deps.LogFile,
|
||||
logPath: deps.LogPath,
|
||||
|
||||
anonymize: cfg.Anonymize,
|
||||
clientStatus: cfg.ClientStatus,
|
||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||
logFileCount: logFileCount,
|
||||
}
|
||||
@@ -288,13 +303,6 @@ func (g *BundleGenerator) createArchive() error {
|
||||
return fmt.Errorf("add status: %w", err)
|
||||
}
|
||||
|
||||
if g.statusRecorder != nil {
|
||||
status := g.statusRecorder.GetFullStatus()
|
||||
seedFromStatus(g.anonymizer, &status)
|
||||
} else {
|
||||
log.Debugf("no status recorder available for seeding")
|
||||
}
|
||||
|
||||
if err := g.addConfig(); err != nil {
|
||||
log.Errorf("failed to add config to debug bundle: %v", err)
|
||||
}
|
||||
@@ -327,7 +335,7 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add wg show output: %v", err)
|
||||
}
|
||||
|
||||
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
|
||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||
if err := g.trySystemdLogFallback(); err != nil {
|
||||
@@ -357,6 +365,10 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
if err := g.addFirewallRules(); err != nil {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addReadme() error {
|
||||
@@ -368,11 +380,26 @@ func (g *BundleGenerator) addReadme() error {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStatus() error {
|
||||
if status := g.clientStatus; status != "" {
|
||||
statusReader := strings.NewReader(status)
|
||||
if g.statusRecorder != nil {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
fullStatus := g.statusRecorder.GetFullStatus()
|
||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
|
||||
statusOutput := nbstatus.ParseToFullDetailSummary(overview)
|
||||
|
||||
statusReader := strings.NewReader(statusOutput)
|
||||
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
|
||||
return fmt.Errorf("add status file to zip: %w", err)
|
||||
}
|
||||
seedFromStatus(g.anonymizer, &fullStatus)
|
||||
} else {
|
||||
log.Debugf("no status recorder available for seeding")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -433,6 +460,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.ServerSSHAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||
}
|
||||
if g.internalConfig.EnableSSHRoot != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot))
|
||||
}
|
||||
if g.internalConfig.EnableSSHSFTP != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP))
|
||||
}
|
||||
if g.internalConfig.EnableSSHLocalPortForwarding != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding))
|
||||
}
|
||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
@@ -630,14 +669,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addLogfile() error {
|
||||
if g.logFile == "" {
|
||||
if g.logPath == "" {
|
||||
log.Debugf("skipping empty log file in debug bundle")
|
||||
return nil
|
||||
}
|
||||
|
||||
logDir := filepath.Dir(g.logFile)
|
||||
logDir := filepath.Dir(g.logPath)
|
||||
|
||||
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
|
||||
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
|
||||
return fmt.Errorf("add client log file to zip: %w", err)
|
||||
}
|
||||
|
||||
|
||||
53
client/internal/debug/debug_darwin.go
Normal file
53
client/internal/debug/debug_darwin.go
Normal file
@@ -0,0 +1,53 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addScutilDNS(); err != nil {
|
||||
log.Errorf("failed to add scutil DNS output: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addScutilDNS() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "scutil", "--dns")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("execute scutil --dns: %w", err)
|
||||
}
|
||||
|
||||
if len(bytes.TrimSpace(output)) == 0 {
|
||||
return fmt.Errorf("no scutil DNS output")
|
||||
}
|
||||
|
||||
content := string(output)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
|
||||
return fmt.Errorf("add scutil DNS output to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -5,3 +5,7 @@ package debug
|
||||
func (g *BundleGenerator) addRoutes() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
16
client/internal/debug/debug_nondarwin.go
Normal file
16
client/internal/debug/debug_nondarwin.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build unix && !darwin && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
7
client/internal/debug/debug_nonunix.go
Normal file
7
client/internal/debug/debug_nonunix.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !unix
|
||||
|
||||
package debug
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
29
client/internal/debug/debug_unix.go
Normal file
29
client/internal/debug/debug_unix.go
Normal file
@@ -0,0 +1,29 @@
|
||||
//go:build unix && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
func (g *BundleGenerator) addResolvConf() error {
|
||||
data, err := os.ReadFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", resolvConfPath, err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
|
||||
return fmt.Errorf("add resolv.conf to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
101
client/internal/debug/upload.go
Normal file
101
client/internal/debug/upload.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const maxBundleUploadSize = 50 * 1024 * 1024
|
||||
|
||||
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
|
||||
response, err := getUploadURL(ctx, url, managementURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = upload(ctx, filePath, response)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return response.Key, nil
|
||||
}
|
||||
|
||||
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
|
||||
fileData, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
|
||||
defer fileData.Close()
|
||||
|
||||
stat, err := fileData.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat file: %w", err)
|
||||
}
|
||||
|
||||
if stat.Size() > maxBundleUploadSize {
|
||||
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create PUT request: %w", err)
|
||||
}
|
||||
|
||||
req.ContentLength = stat.Size()
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
putResp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upload failed: %v", err)
|
||||
}
|
||||
defer putResp.Body.Close()
|
||||
|
||||
if putResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(putResp.Body)
|
||||
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
|
||||
id := getURLHash(managementURL)
|
||||
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create GET request: %w", err)
|
||||
}
|
||||
|
||||
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(getReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get presigned URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
urlBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
var response types.GetURLResponse
|
||||
if err := json.Unmarshal(urlBytes, &response); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func getURLHash(url string) string {
|
||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
|
||||
fileContent := []byte("test file content")
|
||||
err := os.WriteFile(file, fileContent, 0640)
|
||||
require.NoError(t, err)
|
||||
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
||||
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
||||
require.NoError(t, err)
|
||||
id := getURLHash(testURL)
|
||||
require.Contains(t, key, id+"/")
|
||||
@@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct {
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
|
||||
@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Fatalf("create stdnet: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
@@ -30,9 +29,9 @@ import (
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
@@ -50,11 +49,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/jobexec"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -115,7 +115,12 @@ type EngineConfig struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerSSHAllowed bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
|
||||
@@ -129,6 +134,11 @@ type EngineConfig struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// for debug bundle generation
|
||||
ProfileConfig *profilemanager.Config
|
||||
|
||||
LogPath string
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -148,8 +158,6 @@ type Engine struct {
|
||||
|
||||
// syncMsgMux is used to guarantee sequential Management Service message processing
|
||||
syncMsgMux *sync.Mutex
|
||||
// sshMux protects sshServer field access
|
||||
sshMux sync.Mutex
|
||||
|
||||
config *EngineConfig
|
||||
mobileDep MobileDependency
|
||||
@@ -175,8 +183,7 @@ type Engine struct {
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
sshServer sshServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -195,7 +202,8 @@ type Engine struct {
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
// Sync response persistence
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
persistSyncResponse bool
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
@@ -208,6 +216,9 @@ type Engine struct {
|
||||
shutdownWg sync.WaitGroup
|
||||
|
||||
probeStunTurn *relay.StunTurnProbe
|
||||
|
||||
jobExecutor *jobexec.Executor
|
||||
jobExecutorWG sync.WaitGroup
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -221,17 +232,7 @@ type localIpUpdater interface {
|
||||
}
|
||||
|
||||
// NewEngine creates a new Connection Engine with probes attached
|
||||
func NewEngine(
|
||||
clientCtx context.Context,
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
relayManager *relayClient.Manager,
|
||||
config *EngineConfig,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, c *profilemanager.Config) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
@@ -246,11 +247,11 @@ func NewEngine(
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
sshServerFunc: nbssh.DefaultSSHServer,
|
||||
statusRecorder: statusRecorder,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
}
|
||||
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
@@ -268,6 +269,7 @@ func NewEngine(
|
||||
path = mobileDep.StateFilePath
|
||||
}
|
||||
engine.stateManager = statemanager.New(path)
|
||||
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
return engine
|
||||
@@ -292,6 +294,12 @@ func (e *Engine) Stop() error {
|
||||
}
|
||||
log.Info("Network monitor: stopped")
|
||||
|
||||
if err := e.stopSSHServer(); err != nil {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
@@ -324,6 +332,8 @@ func (e *Engine) Stop() error {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||
|
||||
e.close()
|
||||
|
||||
// stop flow manager after wg interface is gone
|
||||
@@ -510,6 +520,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
e.receiveJobEvents()
|
||||
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
@@ -703,16 +714,10 @@ func (e *Engine) removeAllPeers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
|
||||
// removePeer closes an existing peer connection and removes a peer
|
||||
func (e *Engine) removePeer(peerKey string) error {
|
||||
log.Debugf("removing peer from engine %s", peerKey)
|
||||
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
|
||||
e.connMgr.RemovePeerConn(peerKey)
|
||||
|
||||
err := e.statusRecorder.RemovePeer(peerKey)
|
||||
@@ -793,9 +798,18 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// Read the storage-enabled flag under the syncRespMux too.
|
||||
e.syncRespMux.RLock()
|
||||
enabled := e.persistSyncResponse
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
// Store sync response if persistence is enabled
|
||||
if e.persistSyncResponse {
|
||||
if enabled {
|
||||
e.syncRespMux.Lock()
|
||||
e.latestSyncResponse = update
|
||||
e.syncRespMux.Unlock()
|
||||
|
||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||
}
|
||||
|
||||
@@ -884,6 +898,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
@@ -893,74 +912,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNil(server nbssh.Server) bool {
|
||||
return server == nil || reflect.ValueOf(server).IsNil()
|
||||
}
|
||||
|
||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.config.BlockInbound {
|
||||
log.Infof("SSH server is disabled because inbound connections are blocked")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is not enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
e.sshMux.Lock()
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
if nbnetstack.IsEnabled() {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
if err != nil {
|
||||
e.sshMux.Unlock()
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
|
||||
e.sshServer = server
|
||||
e.sshMux.Unlock()
|
||||
|
||||
go func() {
|
||||
// blocking
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
}
|
||||
e.sshMux.Lock()
|
||||
e.sshServer = nil
|
||||
e.sshMux.Unlock()
|
||||
log.Infof("stopped SSH server")
|
||||
}()
|
||||
} else {
|
||||
e.sshMux.Unlock()
|
||||
log.Debugf("SSH server is already running")
|
||||
}
|
||||
} else {
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
}
|
||||
e.sshServer = nil
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
@@ -973,8 +924,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
|
||||
if conf.GetSshConfig() != nil {
|
||||
err := e.updateSSH(conf.GetSshConfig())
|
||||
if err != nil {
|
||||
if err := e.updateSSH(conf.GetSshConfig()); err != nil {
|
||||
log.Warnf("failed handling SSH server setup: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -989,6 +939,77 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
func (e *Engine) receiveJobEvents() {
|
||||
e.jobExecutorWG.Add(1)
|
||||
go func() {
|
||||
defer e.jobExecutorWG.Done()
|
||||
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
|
||||
resp := mgmProto.JobResponse{
|
||||
ID: msg.ID,
|
||||
Status: mgmProto.JobStatus_failed,
|
||||
}
|
||||
switch params := msg.WorkloadParameters.(type) {
|
||||
case *mgmProto.JobRequest_Bundle:
|
||||
bundleResult, err := e.handleBundle(params.Bundle)
|
||||
if err != nil {
|
||||
log.Errorf("handling bundle: %v", err)
|
||||
resp.Reason = []byte(err.Error())
|
||||
return &resp
|
||||
}
|
||||
resp.Status = mgmProto.JobStatus_succeeded
|
||||
resp.WorkloadResults = bundleResult
|
||||
return &resp
|
||||
default:
|
||||
resp.Reason = []byte(jobexec.ErrJobNotImplemented.Error())
|
||||
return &resp
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
log.Info("stopped receiving jobs from Management Service")
|
||||
}()
|
||||
log.Info("connecting to Management Service jobs stream")
|
||||
}
|
||||
|
||||
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
|
||||
log.Infof("handle remote debug bundle request: %s", params.String())
|
||||
syncResponse, err := e.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
log.Warnf("get latest sync response: %v", err)
|
||||
}
|
||||
|
||||
bundleDeps := debug.GeneratorDependencies{
|
||||
InternalConfig: e.config.ProfileConfig,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: e.config.LogPath,
|
||||
}
|
||||
|
||||
bundleJobParams := debug.BundleConfig{
|
||||
Anonymize: params.Anonymize,
|
||||
IncludeSystemInfo: true,
|
||||
LogFileCount: uint32(params.LogFileCount),
|
||||
}
|
||||
|
||||
waitFor := time.Duration(params.BundleForTime) * time.Minute
|
||||
|
||||
uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, waitFor, e.config.ProfileConfig.ManagementURL.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := &mgmProto.JobResponse_Bundle{
|
||||
Bundle: &mgmProto.BundleResult{
|
||||
UploadKey: uploadKey,
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
@@ -1012,6 +1033,11 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
@@ -1170,19 +1196,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
// update SSHServer by adding remote peer SSH keys
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
for _, config := range networkMap.GetRemotePeers() {
|
||||
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
||||
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
|
||||
if err != nil {
|
||||
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
|
||||
|
||||
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
@@ -1259,7 +1277,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) //nolint
|
||||
if forwarderPort == 0 {
|
||||
forwarderPort = nbdns.ForwarderClientPort
|
||||
}
|
||||
@@ -1544,15 +1562,6 @@ func (e *Engine) close() {
|
||||
e.statusRecorder.SetWgIface(nil)
|
||||
}
|
||||
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed stopping the SSH server: %v", err)
|
||||
}
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Close(e.stateManager)
|
||||
if err != nil {
|
||||
@@ -1583,6 +1592,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
@@ -1856,8 +1870,8 @@ func (e *Engine) stopDNSServer() {
|
||||
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence
|
||||
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.syncRespMux.Lock()
|
||||
defer e.syncRespMux.Unlock()
|
||||
|
||||
if enabled == e.persistSyncResponse {
|
||||
return
|
||||
@@ -1872,20 +1886,22 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
||||
|
||||
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
|
||||
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.syncRespMux.RLock()
|
||||
enabled := e.persistSyncResponse
|
||||
latest := e.latestSyncResponse
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
if !e.persistSyncResponse {
|
||||
if !enabled {
|
||||
return nil, errors.New("sync response persistence is disabled")
|
||||
}
|
||||
|
||||
if e.latestSyncResponse == nil {
|
||||
if latest == nil {
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
|
||||
sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
|
||||
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
|
||||
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to clone sync response")
|
||||
}
|
||||
|
||||
355
client/internal/engine_ssh.go
Normal file
355
client/internal/engine_ssh.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type sshServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort) error
|
||||
Stop() error
|
||||
GetStatus() (bool, []sshserver.SessionInfo)
|
||||
}
|
||||
|
||||
func (e *Engine) setupSSHPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||
return fmt.Errorf("add SSH port redirection: %w", err)
|
||||
}
|
||||
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.config.BlockInbound {
|
||||
log.Info("SSH server is disabled because inbound connections are blocked")
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if !e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is disabled in config")
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if !sshConf.GetSshEnabled() {
|
||||
if e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is locally allowed but disabled by management server")
|
||||
}
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if e.sshServer != nil {
|
||||
log.Debug("SSH server is already running")
|
||||
return nil
|
||||
}
|
||||
|
||||
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
|
||||
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
|
||||
return e.startSSHServer(nil)
|
||||
}
|
||||
|
||||
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||
jwtConfig := &sshserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audience: protoJWT.GetAudience(),
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
}
|
||||
|
||||
return e.startSSHServer(jwtConfig)
|
||||
}
|
||||
|
||||
return errors.New("SSH server requires valid JWT configuration")
|
||||
}
|
||||
|
||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||
if len(peerInfo) == 0 {
|
||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||
return nil
|
||||
}
|
||||
|
||||
configMgr := sshconfig.New()
|
||||
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
return nil // Don't fail engine startup on SSH config issues
|
||||
}
|
||||
|
||||
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
|
||||
|
||||
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
|
||||
SSHConfigDir: configMgr.GetSSHConfigDir(),
|
||||
SSHConfigFile: configMgr.GetSSHConfigFile(),
|
||||
}); err != nil {
|
||||
log.Warnf("failed to update SSH config state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractPeerSSHInfo extracts SSH information from peer configurations
|
||||
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
|
||||
var peerInfo []sshconfig.PeerSSHInfo
|
||||
|
||||
for _, peerConfig := range remotePeers {
|
||||
if peerConfig.GetSshConfig() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||
if len(sshPubKeyBytes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
peerIP := e.extractPeerIP(peerConfig)
|
||||
hostname := e.extractHostname(peerConfig)
|
||||
|
||||
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
||||
Hostname: hostname,
|
||||
IP: peerIP,
|
||||
FQDN: peerConfig.GetFqdn(),
|
||||
})
|
||||
}
|
||||
|
||||
return peerInfo
|
||||
}
|
||||
|
||||
// extractPeerIP extracts IP address from peer's allowed IPs
|
||||
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
if len(peerConfig.GetAllowedIps()) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
|
||||
return prefix.Addr().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractHostname extracts short hostname from FQDN
|
||||
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
fqdn := peerConfig.GetFqdn()
|
||||
if fqdn == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Split(fqdn, ".")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
|
||||
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
|
||||
for _, peerConfig := range remotePeers {
|
||||
if peerConfig.GetSshConfig() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||
if len(sshPubKeyBytes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil {
|
||||
log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("updated peer SSH host keys for daemon API access")
|
||||
}
|
||||
|
||||
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
|
||||
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||
e.syncMsgMux.Lock()
|
||||
statusRecorder := e.statusRecorder
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
if statusRecorder == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||
if len(peerState.SSHHostKey) > 0 {
|
||||
return peerState.SSHHostKey, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||
func (e *Engine) cleanupSSHConfig() {
|
||||
configMgr := sshconfig.New()
|
||||
|
||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||
log.Warnf("failed to remove SSH client config: %v", err)
|
||||
} else {
|
||||
log.Debugf("SSH client config cleanup completed")
|
||||
}
|
||||
}
|
||||
|
||||
// startSSHServer initializes and starts the SSH server with proper configuration.
|
||||
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
server.SetNetworkValidation(wgAddr)
|
||||
|
||||
netbirdIP := wgAddr.IP
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
server.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
e.configureSSHServer(server)
|
||||
|
||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||
return fmt.Errorf("start SSH server: %w", err)
|
||||
}
|
||||
|
||||
e.sshServer = server
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||
log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.setupSSHPortRedirection(); err != nil {
|
||||
log.Warnf("failed to setup SSH port redirection: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// configureSSHServer applies SSH configuration options to the server.
|
||||
func (e *Engine) configureSSHServer(server *sshserver.Server) {
|
||||
if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot {
|
||||
server.SetAllowRootLogin(true)
|
||||
log.Info("SSH root login enabled")
|
||||
} else {
|
||||
server.SetAllowRootLogin(false)
|
||||
log.Info("SSH root login disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP {
|
||||
server.SetAllowSFTP(true)
|
||||
log.Info("SSH SFTP subsystem enabled")
|
||||
} else {
|
||||
server.SetAllowSFTP(false)
|
||||
log.Info("SSH SFTP subsystem disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding {
|
||||
server.SetAllowLocalPortForwarding(true)
|
||||
log.Info("SSH local port forwarding enabled")
|
||||
} else {
|
||||
server.SetAllowLocalPortForwarding(false)
|
||||
log.Info("SSH local port forwarding disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding {
|
||||
server.SetAllowRemotePortForwarding(true)
|
||||
log.Info("SSH remote port forwarding enabled")
|
||||
} else {
|
||||
server.SetAllowRemotePortForwarding(false)
|
||||
log.Info("SSH remote port forwarding disabled (default)")
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupSSHPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||
return fmt.Errorf("remove SSH port redirection: %w", err)
|
||||
}
|
||||
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopSSHServer() error {
|
||||
if e.sshServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupSSHPortRedirection(); err != nil {
|
||||
log.Warnf("failed to cleanup SSH port redirection: %v", err)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||
log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("stopping SSH server")
|
||||
err := e.sshServer.Stop()
|
||||
e.sshServer = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSSHServerStatus returns the SSH server status and active sessions
|
||||
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
|
||||
e.syncMsgMux.Lock()
|
||||
sshServer := e.sshServer
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
if sshServer == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return sshServer.GetStatus()
|
||||
}
|
||||
@@ -7,5 +7,5 @@ import (
|
||||
)
|
||||
|
||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(e.config.IFaceBlackList)
|
||||
return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList)
|
||||
}
|
||||
|
||||
@@ -3,5 +3,5 @@ package internal
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||
return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -25,8 +24,15 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
@@ -43,7 +49,7 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@@ -211,11 +217,13 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping TestEngine_SSH")
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -237,45 +245,19 @@ func TestEngine_SSH(t *testing.T) {
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil,
|
||||
nil, nil,
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
var sshKeysAdded []string
|
||||
var sshPeersRemoved []string
|
||||
|
||||
sshCtx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) {
|
||||
return &ssh.MockServer{
|
||||
Ctx: sshCtx,
|
||||
StopFunc: func() error {
|
||||
cancel()
|
||||
return nil
|
||||
},
|
||||
StartFunc: func() error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
},
|
||||
AddAuthorizedKeyFunc: func(peer, newKey string) error {
|
||||
sshKeysAdded = append(sshKeysAdded, newKey)
|
||||
return nil
|
||||
},
|
||||
RemoveAuthorizedKeyFunc: func(peer string) {
|
||||
sshPeersRemoved = append(sshPeersRemoved, peer)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
@@ -301,9 +283,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
@@ -311,19 +291,24 @@ func TestEngine_SSH(t *testing.T) {
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ")
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
@@ -333,13 +318,10 @@ func TestEngine_SSH(t *testing.T) {
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
@@ -351,12 +333,70 @@ func TestEngine_SSH(t *testing.T) {
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||
// Test that SSH server start/stop logic works based on config
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{
|
||||
ServerSSHAllowed: false, // Start with SSH disabled
|
||||
},
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// Test SSH disabled config
|
||||
sshConfig := &mgmtProto.SSHConfig{SshEnabled: false}
|
||||
err := engine.updateSSH(sshConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// Test inbound blocked
|
||||
engine.config.BlockInbound = true
|
||||
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
engine.config.BlockInbound = false
|
||||
|
||||
// Test with server SSH not allowed
|
||||
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_SSHServerConsistency(t *testing.T) {
|
||||
|
||||
t.Run("server set only on successful creation", func(t *testing.T) {
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{
|
||||
ServerSSHAllowed: true,
|
||||
SSHKey: []byte("test-key"),
|
||||
},
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
engine.wgInterface = nil
|
||||
|
||||
err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
})
|
||||
|
||||
t.Run("cleanup handles nil gracefully", func(t *testing.T) {
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{
|
||||
ServerSSHAllowed: false,
|
||||
},
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
err := engine.stopSSHServer()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
@@ -371,21 +411,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
&mgmt.MockClient{},
|
||||
relayMgr,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun102",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
},
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: "utun102",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
|
||||
wgIface := &MockWGIface{
|
||||
NameFunc: func() string { return "utun102" },
|
||||
@@ -604,7 +636,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@@ -769,9 +801,9 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -971,10 +1003,10 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1497,7 +1529,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
@@ -1556,7 +1588,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
jobManager := job.NewJobManager(nil, store)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -1584,13 +1616,16 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -124,6 +124,11 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
@@ -150,6 +155,11 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
|
||||
@@ -666,7 +666,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
}
|
||||
}()
|
||||
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package peer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -10,5 +11,8 @@ const (
|
||||
)
|
||||
|
||||
func isForceRelayed() bool {
|
||||
if runtime.GOOS == "js" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
||||
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
||||
log.Debugf("Gathering ICE candidates")
|
||||
|
||||
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
||||
agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("create ICE agent: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -22,6 +23,8 @@ const (
|
||||
iceFailedTimeoutDefault = 6 * time.Second
|
||||
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
||||
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
||||
// iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete
|
||||
iceAgentCloseTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
type ThreadSafeAgent struct {
|
||||
@@ -32,18 +35,28 @@ type ThreadSafeAgent struct {
|
||||
func (a *ThreadSafeAgent) Close() error {
|
||||
var err error
|
||||
a.once.Do(func() {
|
||||
err = a.Agent.Close()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- a.Agent.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(iceAgentCloseTimeout):
|
||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||
err = nil
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||
iceKeepAlive := iceKeepAlive()
|
||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||
iceFailedTimeout := iceFailedTimeout()
|
||||
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||
|
||||
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
|
||||
transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(ifaceBlacklist)
|
||||
func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(ctx, ifaceBlacklist)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package ice
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
import (
|
||||
"context"
|
||||
|
||||
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||
return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist)
|
||||
}
|
||||
|
||||
@@ -21,9 +21,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const eventQueueSize = 10
|
||||
@@ -67,6 +67,7 @@ type State struct {
|
||||
BytesRx int64
|
||||
Latency time.Duration
|
||||
RosenpassEnabled bool
|
||||
SSHHostKey []byte
|
||||
routes map[string]struct{}
|
||||
}
|
||||
|
||||
@@ -572,6 +573,22 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeerSSHHostKey updates peer's SSH host key
|
||||
func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.SSHHostKey = sshHostKey
|
||||
d.peers[peerPubKey] = peerState
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FinishPeerListModifications this event invoke the notification
|
||||
func (d *Status) FinishPeerListModifications() {
|
||||
d.mux.Lock()
|
||||
|
||||
@@ -209,7 +209,7 @@ func (w *WorkerICE) Close() {
|
||||
}
|
||||
|
||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create agent: %w", err)
|
||||
}
|
||||
@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
|
||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||
if isController(w.config) {
|
||||
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
} else {
|
||||
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
}
|
||||
|
||||
@@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct {
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
|
||||
@@ -44,24 +44,30 @@ var DefaultInterfaceBlacklist = []string{
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
RosenpassEnabled *bool
|
||||
RosenpassPermissive *bool
|
||||
InterfaceName *string
|
||||
WireguardPort *int
|
||||
NetworkMonitor *bool
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
ClientCertPath string
|
||||
ClientCertKeyPath string
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
SSHJWTCacheTTL *int
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
RosenpassEnabled *bool
|
||||
RosenpassPermissive *bool
|
||||
InterfaceName *string
|
||||
WireguardPort *int
|
||||
NetworkMonitor *bool
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
ClientCertPath string
|
||||
ClientCertKeyPath string
|
||||
|
||||
DisableClientRoutes *bool
|
||||
DisableServerRoutes *bool
|
||||
@@ -82,18 +88,24 @@ type ConfigInput struct {
|
||||
// Config Configuration type
|
||||
type Config struct {
|
||||
// Wireguard private key of local peer
|
||||
PrivateKey string
|
||||
PreSharedKey string
|
||||
ManagementURL *url.URL
|
||||
AdminURL *url.URL
|
||||
WgIface string
|
||||
WgPort int
|
||||
NetworkMonitor *bool
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
PrivateKey string
|
||||
PreSharedKey string
|
||||
ManagementURL *url.URL
|
||||
AdminURL *url.URL
|
||||
WgIface string
|
||||
WgPort int
|
||||
NetworkMonitor *bool
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
SSHJWTCacheTTL *int
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -376,6 +388,62 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
} else {
|
||||
log.Infof("disabling SSH root login")
|
||||
}
|
||||
config.EnableSSHRoot = input.EnableSSHRoot
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
|
||||
if *input.EnableSSHSFTP {
|
||||
log.Infof("enabling SSH SFTP subsystem")
|
||||
} else {
|
||||
log.Infof("disabling SSH SFTP subsystem")
|
||||
}
|
||||
config.EnableSSHSFTP = input.EnableSSHSFTP
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
|
||||
if *input.EnableSSHLocalPortForwarding {
|
||||
log.Infof("enabling SSH local port forwarding")
|
||||
} else {
|
||||
log.Infof("disabling SSH local port forwarding")
|
||||
}
|
||||
config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
|
||||
if *input.EnableSSHRemotePortForwarding {
|
||||
log.Infof("enabling SSH remote port forwarding")
|
||||
} else {
|
||||
log.Infof("disabling SSH remote port forwarding")
|
||||
}
|
||||
config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
|
||||
if *input.DisableSSHAuth {
|
||||
log.Infof("disabling SSH authentication")
|
||||
} else {
|
||||
log.Infof("enabling SSH authentication")
|
||||
}
|
||||
config.DisableSSHAuth = input.DisableSSHAuth
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||
|
||||
@@ -193,10 +193,10 @@ func TestWireguardPortZeroExplicit(t *testing.T) {
|
||||
|
||||
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wireguardPort *int
|
||||
expectedPort int
|
||||
description string
|
||||
name string
|
||||
wireguardPort *int
|
||||
expectedPort int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no port specified uses default",
|
||||
|
||||
@@ -132,3 +132,21 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLoginHint retrieves the email from the active profile to use as login_hint.
|
||||
func GetLoginHint() string {
|
||||
pm := NewProfileManager()
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get active profile for login hint: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return profileState.Email
|
||||
}
|
||||
|
||||
@@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri
|
||||
}
|
||||
}()
|
||||
|
||||
net, err := stdnet.NewNet(nil)
|
||||
net, err := stdnet.NewNet(ctx, nil)
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("new net: %w", err)
|
||||
return
|
||||
@@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
|
||||
}
|
||||
}()
|
||||
|
||||
net, err := stdnet.NewNet(nil)
|
||||
net, err := stdnet.NewNet(ctx, nil)
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("new net: %w", err)
|
||||
return
|
||||
|
||||
@@ -18,8 +18,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||
@@ -39,6 +38,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
|
||||
@@ -4,17 +4,28 @@
|
||||
package stdnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/pion/transport/v3"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
const updateInterval = 30 * time.Second
|
||||
const (
|
||||
updateInterval = 30 * time.Second
|
||||
dnsResolveTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
var errNoSuitableAddress = errors.New("no suitable address found")
|
||||
|
||||
// Net is an implementation of the net.Net interface
|
||||
// based on functions of the standard net package.
|
||||
@@ -28,12 +39,19 @@ type Net struct {
|
||||
|
||||
// mu is shared between interfaces and lastUpdate
|
||||
mu sync.Mutex
|
||||
|
||||
// ctx is the context for network operations that supports cancellation
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewNetWithDiscover creates a new StdNet instance.
|
||||
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
||||
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
n := &Net{
|
||||
interfaceFilter: InterfaceFilter(disallowList),
|
||||
ctx: ctx,
|
||||
}
|
||||
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
|
||||
// so in android cli use pionDiscover
|
||||
@@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
|
||||
}
|
||||
|
||||
// NewNet creates a new StdNet instance.
|
||||
func NewNet(disallowList []string) (*Net, error) {
|
||||
func NewNet(ctx context.Context, disallowList []string) (*Net, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
n := &Net{
|
||||
iFaceDiscover: pionDiscover{},
|
||||
interfaceFilter: InterfaceFilter(disallowList),
|
||||
ctx: ctx,
|
||||
}
|
||||
return n, n.UpdateInterfaces()
|
||||
}
|
||||
|
||||
// resolveAddr performs DNS resolution with context support and timeout.
|
||||
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
|
||||
host, portStr, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
if port < 0 || port > 65535 {
|
||||
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
|
||||
}
|
||||
|
||||
ipNet := "ip"
|
||||
switch network {
|
||||
case "tcp4", "udp4":
|
||||
ipNet = "ip4"
|
||||
case "tcp6", "udp6":
|
||||
ipNet = "ip6"
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
addr := netip.IPv4Unspecified()
|
||||
if ipNet == "ip6" {
|
||||
addr = netip.IPv6Unspecified()
|
||||
}
|
||||
return netip.AddrPortFrom(addr, uint16(port)), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
|
||||
defer cancel()
|
||||
|
||||
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
|
||||
if len(addrs) == 0 {
|
||||
return netip.AddrPort{}, errNoSuitableAddress
|
||||
}
|
||||
|
||||
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
|
||||
}
|
||||
|
||||
// UpdateInterfaces updates the internal list of network interfaces
|
||||
// and associated addresses filtering them by name.
|
||||
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
||||
@@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
|
||||
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
case "":
|
||||
network = "udp"
|
||||
default:
|
||||
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
||||
}
|
||||
|
||||
addrPort, err := n.resolveAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
|
||||
}
|
||||
|
||||
return net.UDPAddrFromAddrPort(addrPort), nil
|
||||
}
|
||||
|
||||
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
|
||||
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
case "":
|
||||
network = "tcp"
|
||||
default:
|
||||
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
||||
}
|
||||
|
||||
addrPort, err := n.resolveAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
|
||||
}
|
||||
|
||||
return net.TCPAddrFromAddrPort(addrPort), nil
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
299
client/internal/templates/pkce_auth_msg_test.go
Normal file
299
client/internal/templates/pkce_auth_msg_test.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package templates
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPKCEAuthMsgTemplate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data map[string]string
|
||||
outputFile string
|
||||
expectedTitle string
|
||||
expectedInContent []string
|
||||
notExpectedInContent []string
|
||||
}{
|
||||
{
|
||||
name: "error_state",
|
||||
data: map[string]string{
|
||||
"Error": "authentication failed: invalid state",
|
||||
},
|
||||
outputFile: "pkce-auth-error.html",
|
||||
expectedTitle: "Login Failed",
|
||||
expectedInContent: []string{
|
||||
"authentication failed: invalid state",
|
||||
"Login Failed",
|
||||
},
|
||||
notExpectedInContent: []string{
|
||||
"Login Successful",
|
||||
"Your device is now registered and logged in to NetBird",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success_state",
|
||||
data: map[string]string{
|
||||
// No error field means success
|
||||
},
|
||||
outputFile: "pkce-auth-success.html",
|
||||
expectedTitle: "Login Successful",
|
||||
expectedInContent: []string{
|
||||
"Login Successful",
|
||||
"Your device is now registered and logged in to NetBird. You can now close this window.",
|
||||
},
|
||||
notExpectedInContent: []string{
|
||||
"Login Failed",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error_state_timeout",
|
||||
data: map[string]string{
|
||||
"Error": "authentication timeout: request expired after 5 minutes",
|
||||
},
|
||||
outputFile: "pkce-auth-timeout.html",
|
||||
expectedTitle: "Login Failed",
|
||||
expectedInContent: []string{
|
||||
"authentication timeout: request expired after 5 minutes",
|
||||
"Login Failed",
|
||||
},
|
||||
notExpectedInContent: []string{
|
||||
"Login Successful",
|
||||
"Your device is now registered and logged in to NetBird",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Parse the template
|
||||
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
// Create temp directory for this test
|
||||
tempDir := t.TempDir()
|
||||
outputPath := filepath.Join(tempDir, tt.outputFile)
|
||||
|
||||
// Create output file
|
||||
file, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create output file: %v", err)
|
||||
}
|
||||
|
||||
// Execute the template
|
||||
if err := tmpl.Execute(file, tt.data); err != nil {
|
||||
file.Close()
|
||||
t.Fatalf("Failed to execute template: %v", err)
|
||||
}
|
||||
file.Close()
|
||||
|
||||
t.Logf("Generated test output: %s", outputPath)
|
||||
|
||||
// Read the generated file
|
||||
content, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read output file: %v", err)
|
||||
}
|
||||
|
||||
contentStr := string(content)
|
||||
|
||||
// Verify file has content
|
||||
if len(contentStr) == 0 {
|
||||
t.Error("Output file is empty")
|
||||
}
|
||||
|
||||
// Verify basic HTML structure
|
||||
basicElements := []string{
|
||||
"<!DOCTYPE html>",
|
||||
"<html",
|
||||
"<head>",
|
||||
"<body>",
|
||||
"NetBird",
|
||||
}
|
||||
|
||||
for _, elem := range basicElements {
|
||||
if !contains(contentStr, elem) {
|
||||
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify expected title
|
||||
if !contains(contentStr, tt.expectedTitle) {
|
||||
t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle)
|
||||
}
|
||||
|
||||
// Verify expected content is present
|
||||
for _, expected := range tt.expectedInContent {
|
||||
if !contains(contentStr, expected) {
|
||||
t.Errorf("Expected HTML to contain '%s', but it was not found", expected)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify unexpected content is not present
|
||||
for _, notExpected := range tt.notExpectedInContent {
|
||||
if contains(contentStr, notExpected) {
|
||||
t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPKCEAuthMsgTemplateValidation(t *testing.T) {
|
||||
// Test that the template can be parsed without errors
|
||||
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||
if err != nil {
|
||||
t.Fatalf("Template parsing failed: %v", err)
|
||||
}
|
||||
|
||||
// Test with empty data
|
||||
t.Run("empty_data", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
outputPath := filepath.Join(tempDir, "empty-data.html")
|
||||
|
||||
file, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create output file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if err := tmpl.Execute(file, nil); err != nil {
|
||||
t.Errorf("Template execution with nil data failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test with error data
|
||||
t.Run("with_error", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
outputPath := filepath.Join(tempDir, "with-error.html")
|
||||
|
||||
file, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create output file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data := map[string]string{
|
||||
"Error": "test error message",
|
||||
}
|
||||
if err := tmpl.Execute(file, data); err != nil {
|
||||
t.Errorf("Template execution with error data failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPKCEAuthMsgTemplateContent(t *testing.T) {
|
||||
// Test that the template contains expected elements
|
||||
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||
if err != nil {
|
||||
t.Fatalf("Template parsing failed: %v", err)
|
||||
}
|
||||
|
||||
t.Run("success_content", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
outputPath := filepath.Join(tempDir, "success.html")
|
||||
|
||||
file, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create output file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data := map[string]string{}
|
||||
if err := tmpl.Execute(file, data); err != nil {
|
||||
t.Fatalf("Template execution failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the file and verify it contains expected content
|
||||
content, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read output file: %v", err)
|
||||
}
|
||||
|
||||
// Check for success indicators
|
||||
contentStr := string(content)
|
||||
if len(contentStr) == 0 {
|
||||
t.Error("Generated HTML is empty")
|
||||
}
|
||||
|
||||
// Basic HTML structure checks
|
||||
requiredElements := []string{
|
||||
"<!DOCTYPE html>",
|
||||
"<html",
|
||||
"<head>",
|
||||
"<body>",
|
||||
"Login Successful",
|
||||
"NetBird",
|
||||
}
|
||||
|
||||
for _, elem := range requiredElements {
|
||||
if !contains(contentStr, elem) {
|
||||
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error_content", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
outputPath := filepath.Join(tempDir, "error.html")
|
||||
|
||||
file, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create output file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
errorMsg := "test error message"
|
||||
data := map[string]string{
|
||||
"Error": errorMsg,
|
||||
}
|
||||
if err := tmpl.Execute(file, data); err != nil {
|
||||
t.Fatalf("Template execution failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the file and verify it contains expected content
|
||||
content, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read output file: %v", err)
|
||||
}
|
||||
|
||||
// Check for error indicators
|
||||
contentStr := string(content)
|
||||
if len(contentStr) == 0 {
|
||||
t.Error("Generated HTML is empty")
|
||||
}
|
||||
|
||||
// Basic HTML structure checks
|
||||
requiredElements := []string{
|
||||
"<!DOCTYPE html>",
|
||||
"<html",
|
||||
"<head>",
|
||||
"<body>",
|
||||
"Login Failed",
|
||||
errorMsg,
|
||||
}
|
||||
|
||||
for _, elem := range requiredElements {
|
||||
if !contains(contentStr, elem) {
|
||||
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
||||
(len(s) > 0 && len(substr) > 0 && containsHelper(s, substr)))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string {
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
66
client/jobexec/executor.go
Normal file
66
client/jobexec/executor.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package jobexec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxBundleWaitTime = 60 * time.Minute // maximum wait time for bundle generation (1 hour)
|
||||
)
|
||||
|
||||
var (
|
||||
ErrJobNotImplemented = errors.New("job not implemented")
|
||||
)
|
||||
|
||||
type Executor struct {
|
||||
}
|
||||
|
||||
func NewExecutor() *Executor {
|
||||
return &Executor{}
|
||||
}
|
||||
|
||||
func (e *Executor) BundleJob(ctx context.Context, debugBundleDependencies debug.GeneratorDependencies, params debug.BundleConfig, waitForDuration time.Duration, mgmURL string) (string, error) {
|
||||
if waitForDuration > MaxBundleWaitTime {
|
||||
log.Warnf("bundle wait time %v exceeds maximum %v, capping to maximum", waitForDuration, MaxBundleWaitTime)
|
||||
waitForDuration = MaxBundleWaitTime
|
||||
}
|
||||
|
||||
if waitForDuration > 0 {
|
||||
waitFor(ctx, waitForDuration)
|
||||
}
|
||||
|
||||
log.Infof("execute debug bundle generation")
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(debugBundleDependencies, params)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||
}
|
||||
|
||||
key, err := debug.UploadDebugBundle(ctx, types.DefaultBundleURL, mgmURL, path)
|
||||
if err != nil {
|
||||
log.Errorf("failed to upload debug bundle: %v", err)
|
||||
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("debug bundle has been generated well")
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func waitFor(ctx context.Context, duration time.Duration) {
|
||||
log.Infof("wait for %v minutes before executing debug bundle", duration.Minutes())
|
||||
select {
|
||||
case <-time.After(duration):
|
||||
case <-ctx.Done():
|
||||
log.Infof("wait cancelled: %v", ctx.Err())
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -84,6 +84,15 @@ service DaemonService {
|
||||
rpc Logout(LogoutRequest) returns (LogoutResponse) {}
|
||||
|
||||
rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {}
|
||||
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
|
||||
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||
}
|
||||
|
||||
|
||||
@@ -158,6 +167,16 @@ message LoginRequest {
|
||||
optional string username = 31;
|
||||
|
||||
optional int64 mtu = 32;
|
||||
|
||||
// hint is used to pre-fill the email/username field during SSO authentication
|
||||
optional string hint = 33;
|
||||
|
||||
optional bool enableSSHRoot = 34;
|
||||
optional bool enableSSHSFTP = 35;
|
||||
optional bool enableSSHLocalPortForwarding = 36;
|
||||
optional bool enableSSHRemotePortForwarding = 37;
|
||||
optional bool disableSSHAuth = 38;
|
||||
optional int32 sshJWTCacheTTL = 39;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -185,9 +204,9 @@ message UpResponse {}
|
||||
|
||||
message StatusRequest{
|
||||
bool getFullPeerStatus = 1;
|
||||
bool shouldRunProbes = 2;
|
||||
bool shouldRunProbes = 2;
|
||||
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
|
||||
optional bool waitForReady = 3;
|
||||
optional bool waitForReady = 3;
|
||||
}
|
||||
|
||||
message StatusResponse{
|
||||
@@ -252,6 +271,18 @@ message GetConfigResponse {
|
||||
bool disable_server_routes = 19;
|
||||
|
||||
bool block_lan_access = 20;
|
||||
|
||||
bool enableSSHRoot = 21;
|
||||
|
||||
bool enableSSHSFTP = 24;
|
||||
|
||||
bool enableSSHLocalPortForwarding = 22;
|
||||
|
||||
bool enableSSHRemotePortForwarding = 23;
|
||||
|
||||
bool disableSSHAuth = 25;
|
||||
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -273,6 +304,7 @@ message PeerState {
|
||||
repeated string networks = 16;
|
||||
google.protobuf.Duration latency = 17;
|
||||
string relayAddress = 18;
|
||||
bytes sshHostKey = 19;
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
@@ -314,6 +346,20 @@ message NSGroupState {
|
||||
string error = 4;
|
||||
}
|
||||
|
||||
// SSHSessionInfo contains information about an active SSH session
|
||||
message SSHSessionInfo {
|
||||
string username = 1;
|
||||
string remoteAddress = 2;
|
||||
string command = 3;
|
||||
string jwtUsername = 4;
|
||||
}
|
||||
|
||||
// SSHServerState contains the latest state of the SSH server
|
||||
message SSHServerState {
|
||||
bool enabled = 1;
|
||||
repeated SSHSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
message FullStatus {
|
||||
ManagementState managementState = 1;
|
||||
@@ -327,6 +373,7 @@ message FullStatus {
|
||||
repeated SystemEvent events = 7;
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -387,7 +434,6 @@ message ForwardingRulesResponse {
|
||||
// DebugBundler
|
||||
message DebugBundleRequest {
|
||||
bool anonymize = 1;
|
||||
string status = 2;
|
||||
bool systemInfo = 3;
|
||||
string uploadURL = 4;
|
||||
uint32 logFileCount = 5;
|
||||
@@ -540,56 +586,63 @@ message SwitchProfileRequest {
|
||||
message SwitchProfileResponse {}
|
||||
|
||||
message SetConfigRequest {
|
||||
string username = 1;
|
||||
string profileName = 2;
|
||||
// managementUrl to authenticate.
|
||||
string managementUrl = 3;
|
||||
string username = 1;
|
||||
string profileName = 2;
|
||||
// managementUrl to authenticate.
|
||||
string managementUrl = 3;
|
||||
|
||||
// adminUrl to manage keys.
|
||||
string adminURL = 4;
|
||||
// adminUrl to manage keys.
|
||||
string adminURL = 4;
|
||||
|
||||
optional bool rosenpassEnabled = 5;
|
||||
optional bool rosenpassEnabled = 5;
|
||||
|
||||
optional string interfaceName = 6;
|
||||
optional string interfaceName = 6;
|
||||
|
||||
optional int64 wireguardPort = 7;
|
||||
optional int64 wireguardPort = 7;
|
||||
|
||||
optional string optionalPreSharedKey = 8;
|
||||
optional string optionalPreSharedKey = 8;
|
||||
|
||||
optional bool disableAutoConnect = 9;
|
||||
optional bool disableAutoConnect = 9;
|
||||
|
||||
optional bool serverSSHAllowed = 10;
|
||||
optional bool serverSSHAllowed = 10;
|
||||
|
||||
optional bool rosenpassPermissive = 11;
|
||||
optional bool rosenpassPermissive = 11;
|
||||
|
||||
optional bool networkMonitor = 12;
|
||||
optional bool networkMonitor = 12;
|
||||
|
||||
optional bool disable_client_routes = 13;
|
||||
optional bool disable_server_routes = 14;
|
||||
optional bool disable_dns = 15;
|
||||
optional bool disable_firewall = 16;
|
||||
optional bool block_lan_access = 17;
|
||||
optional bool disable_client_routes = 13;
|
||||
optional bool disable_server_routes = 14;
|
||||
optional bool disable_dns = 15;
|
||||
optional bool disable_firewall = 16;
|
||||
optional bool block_lan_access = 17;
|
||||
|
||||
optional bool disable_notifications = 18;
|
||||
optional bool disable_notifications = 18;
|
||||
|
||||
optional bool lazyConnectionEnabled = 19;
|
||||
optional bool lazyConnectionEnabled = 19;
|
||||
|
||||
optional bool block_inbound = 20;
|
||||
optional bool block_inbound = 20;
|
||||
|
||||
repeated string natExternalIPs = 21;
|
||||
bool cleanNATExternalIPs = 22;
|
||||
repeated string natExternalIPs = 21;
|
||||
bool cleanNATExternalIPs = 22;
|
||||
|
||||
bytes customDNSAddress = 23;
|
||||
bytes customDNSAddress = 23;
|
||||
|
||||
repeated string extraIFaceBlacklist = 24;
|
||||
repeated string extraIFaceBlacklist = 24;
|
||||
|
||||
repeated string dns_labels = 25;
|
||||
// cleanDNSLabels clean map list of DNS labels.
|
||||
bool cleanDNSLabels = 26;
|
||||
repeated string dns_labels = 25;
|
||||
// cleanDNSLabels clean map list of DNS labels.
|
||||
bool cleanDNSLabels = 26;
|
||||
|
||||
optional google.protobuf.Duration dnsRouteInterval = 27;
|
||||
optional google.protobuf.Duration dnsRouteInterval = 27;
|
||||
|
||||
optional int64 mtu = 28;
|
||||
optional int64 mtu = 28;
|
||||
|
||||
optional bool enableSSHRoot = 29;
|
||||
optional bool enableSSHSFTP = 30;
|
||||
optional bool enableSSHLocalPortForwarding = 31;
|
||||
optional bool enableSSHRemotePortForwarding = 32;
|
||||
optional bool disableSSHAuth = 33;
|
||||
optional int32 sshJWTCacheTTL = 34;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
@@ -641,3 +694,63 @@ message GetFeaturesResponse{
|
||||
bool disable_profiles = 1;
|
||||
bool disable_update_settings = 2;
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer
|
||||
message GetPeerSSHHostKeyRequest {
|
||||
// peer IP address or FQDN to get SSH host key for
|
||||
string peerAddress = 1;
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer
|
||||
message GetPeerSSHHostKeyResponse {
|
||||
// SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname")
|
||||
bytes sshHostKey = 1;
|
||||
// peer IP address
|
||||
string peerIP = 2;
|
||||
// peer FQDN
|
||||
string peerFQDN = 3;
|
||||
// indicates if the SSH host key was found
|
||||
bool found = 4;
|
||||
}
|
||||
|
||||
// RequestJWTAuthRequest for initiating JWT authentication flow
|
||||
message RequestJWTAuthRequest {
|
||||
// hint for OIDC login_hint parameter (typically email address)
|
||||
optional string hint = 1;
|
||||
}
|
||||
|
||||
// RequestJWTAuthResponse contains authentication flow information
|
||||
message RequestJWTAuthResponse {
|
||||
// verification URI for user authentication
|
||||
string verificationURI = 1;
|
||||
// complete verification URI (with embedded user code)
|
||||
string verificationURIComplete = 2;
|
||||
// user code to enter on verification URI
|
||||
string userCode = 3;
|
||||
// device code for polling
|
||||
string deviceCode = 4;
|
||||
// expiration time in seconds
|
||||
int64 expiresIn = 5;
|
||||
// if a cached token is available, it will be returned here
|
||||
string cachedToken = 6;
|
||||
// maximum age of JWT tokens in seconds (from management server)
|
||||
int64 maxTokenAge = 7;
|
||||
}
|
||||
|
||||
// WaitJWTTokenRequest for waiting for authentication completion
|
||||
message WaitJWTTokenRequest {
|
||||
// device code from RequestJWTAuthResponse
|
||||
string deviceCode = 1;
|
||||
// user code for verification
|
||||
string userCode = 2;
|
||||
}
|
||||
|
||||
// WaitJWTTokenResponse contains the JWT token after authentication
|
||||
message WaitJWTTokenResponse {
|
||||
// JWT token (access token or ID token)
|
||||
string token = 1;
|
||||
// token type (e.g., "Bearer")
|
||||
string tokenType = 2;
|
||||
// expiration time in seconds
|
||||
int64 expiresIn = 3;
|
||||
}
|
||||
|
||||
@@ -64,6 +64,12 @@ type DaemonServiceClient interface {
|
||||
// Logout disconnects from the network and deletes the peer from the management server
|
||||
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
|
||||
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error)
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -349,6 +355,33 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) {
|
||||
out := new(GetPeerSSHHostKeyResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) {
|
||||
out := new(RequestJWTAuthResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) {
|
||||
out := new(WaitJWTTokenResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility
|
||||
@@ -399,6 +432,12 @@ type DaemonServiceServer interface {
|
||||
// Logout disconnects from the network and deletes the peer from the management server
|
||||
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
|
||||
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error)
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -490,6 +529,15 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest)
|
||||
func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
|
||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -1010,6 +1058,60 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(GetPeerSSHHostKeyRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RequestJWTAuthRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/RequestJWTAuth",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(WaitJWTTokenRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).WaitJWTToken(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/WaitJWTToken",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -1125,6 +1227,18 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetFeatures",
|
||||
Handler: _DaemonService_GetFeatures_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetPeerSSHHostKey",
|
||||
Handler: _DaemonService_GetPeerSSHHostKey_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RequestJWTAuth",
|
||||
Handler: _DaemonService_RequestJWTAuth_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "WaitJWTToken",
|
||||
Handler: _DaemonService_WaitJWTToken_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
@@ -4,24 +4,16 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const maxBundleUploadSize = 50 * 1024 * 1024
|
||||
|
||||
// DebugBundle creates a debug bundle and returns the location.
|
||||
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
||||
s.mutex.Lock()
|
||||
@@ -37,11 +29,10 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
InternalConfig: s.config,
|
||||
StatusRecorder: s.statusRecorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogFile: s.logFile,
|
||||
LogPath: s.logFile,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
Anonymize: req.GetAnonymize(),
|
||||
ClientStatus: req.GetStatus(),
|
||||
IncludeSystemInfo: req.GetSystemInfo(),
|
||||
LogFileCount: req.GetLogFileCount(),
|
||||
},
|
||||
@@ -55,7 +46,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
if req.GetUploadURL() == "" {
|
||||
return &proto.DebugBundleResponse{Path: path}, nil
|
||||
}
|
||||
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
|
||||
key, err := debug.UploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
|
||||
if err != nil {
|
||||
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
|
||||
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
|
||||
@@ -66,92 +57,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
|
||||
}
|
||||
|
||||
func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
|
||||
response, err := getUploadURL(ctx, url, managementURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = upload(ctx, filePath, response)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return response.Key, nil
|
||||
}
|
||||
|
||||
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
|
||||
fileData, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
|
||||
defer fileData.Close()
|
||||
|
||||
stat, err := fileData.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat file: %w", err)
|
||||
}
|
||||
|
||||
if stat.Size() > maxBundleUploadSize {
|
||||
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create PUT request: %w", err)
|
||||
}
|
||||
|
||||
req.ContentLength = stat.Size()
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
putResp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upload failed: %v", err)
|
||||
}
|
||||
defer putResp.Body.Close()
|
||||
|
||||
if putResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(putResp.Body)
|
||||
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
|
||||
id := getURLHash(managementURL)
|
||||
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create GET request: %w", err)
|
||||
}
|
||||
|
||||
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(getReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get presigned URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
urlBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
var response types.GetURLResponse
|
||||
if err := json.Unmarshal(urlBytes, &response); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func getURLHash(url string) string {
|
||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
|
||||
}
|
||||
|
||||
// GetLogLevel gets the current logging level for the server.
|
||||
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
|
||||
s.mutex.Lock()
|
||||
|
||||
79
client/server/jwt_cache.go
Normal file
79
client/server/jwt_cache.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type jwtCache struct {
|
||||
mu sync.RWMutex
|
||||
enclave *memguard.Enclave
|
||||
expiresAt time.Time
|
||||
timer *time.Timer
|
||||
maxTokenSize int
|
||||
}
|
||||
|
||||
func newJWTCache() *jwtCache {
|
||||
return &jwtCache{
|
||||
maxTokenSize: 8192,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *jwtCache) store(token string, maxAge time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cleanup()
|
||||
|
||||
if c.timer != nil {
|
||||
c.timer.Stop()
|
||||
}
|
||||
|
||||
tokenBytes := []byte(token)
|
||||
c.enclave = memguard.NewEnclave(tokenBytes)
|
||||
|
||||
c.expiresAt = time.Now().Add(maxAge)
|
||||
|
||||
var timer *time.Timer
|
||||
timer = time.AfterFunc(maxAge, func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.timer != timer {
|
||||
return
|
||||
}
|
||||
c.cleanup()
|
||||
c.timer = nil
|
||||
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
|
||||
})
|
||||
c.timer = timer
|
||||
}
|
||||
|
||||
func (c *jwtCache) get() (string, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.enclave == nil || time.Now().After(c.expiresAt) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
buffer, err := c.enclave.Open()
|
||||
if err != nil {
|
||||
log.Debugf("Failed to open JWT token enclave: %v", err)
|
||||
return "", false
|
||||
}
|
||||
defer buffer.Destroy()
|
||||
|
||||
token := string(buffer.Bytes())
|
||||
return token, true
|
||||
}
|
||||
|
||||
// cleanup destroys the secure enclave, must be called with lock held
|
||||
func (c *jwtCache) cleanup() {
|
||||
if c.enclave != nil {
|
||||
c.enclave = nil
|
||||
}
|
||||
c.expiresAt = time.Time{}
|
||||
}
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type selectRoute struct {
|
||||
|
||||
@@ -13,15 +13,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
@@ -32,6 +28,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -46,6 +43,9 @@ const (
|
||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||
defaultRetryMultiplier = 1.7
|
||||
|
||||
// JWT token cache TTL for the client daemon (disabled by default)
|
||||
defaultJWTCacheTTL = 0
|
||||
|
||||
errRestoreResidualState = "failed to restore residual state: %v"
|
||||
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
||||
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
|
||||
@@ -81,6 +81,8 @@ type Server struct {
|
||||
profileManager *profilemanager.ServiceManager
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
|
||||
jwtCache *jwtCache
|
||||
}
|
||||
|
||||
type oauthAuthFlow struct {
|
||||
@@ -100,6 +102,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
profileManager: profilemanager.NewServiceManager(configFile),
|
||||
profilesDisabled: profilesDisabled,
|
||||
updateSettingsDisabled: updateSettingsDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,6 +376,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.DisableNotifications = msg.DisableNotifications
|
||||
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
|
||||
config.BlockInbound = msg.BlockInbound
|
||||
config.EnableSSHRoot = msg.EnableSSHRoot
|
||||
config.EnableSSHSFTP = msg.EnableSSHSFTP
|
||||
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding
|
||||
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding
|
||||
if msg.DisableSSHAuth != nil {
|
||||
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||
}
|
||||
if msg.SshJWTCacheTTL != nil {
|
||||
ttl := int(*msg.SshJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = &ttl
|
||||
}
|
||||
|
||||
if msg.Mtu != nil {
|
||||
mtu := uint16(*msg.Mtu)
|
||||
@@ -483,13 +497,17 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
if msg.SetupKey == "" {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
|
||||
hint := ""
|
||||
if msg.Hint != nil {
|
||||
hint = *msg.Hint
|
||||
}
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
|
||||
if err != nil {
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
|
||||
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||
log.Debugf("using previous oauth flow info")
|
||||
return &proto.LoginResponse{
|
||||
@@ -506,7 +524,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("getting a request OAuth flow failed: %v", err)
|
||||
return nil, err
|
||||
@@ -1059,14 +1077,237 @@ func (s *Server) Status(
|
||||
if msg.GetFullPeerStatus {
|
||||
s.runProbes(msg.ShouldRunProbes)
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
pbFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
return &statusResponse, nil
|
||||
}
|
||||
|
||||
// getSSHServerState retrieves the current SSH server state including enabled status and active sessions
|
||||
func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabled, sessions := engine.GetSSHServerStatus()
|
||||
sshServerState := &proto.SSHServerState{
|
||||
Enabled: enabled,
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{
|
||||
Username: session.Username,
|
||||
RemoteAddress: session.RemoteAddress,
|
||||
Command: session.Command,
|
||||
JwtUsername: session.JWTUsername,
|
||||
})
|
||||
}
|
||||
|
||||
return sshServerState
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
req *proto.GetPeerSSHHostKeyRequest,
|
||||
) (*proto.GetPeerSSHHostKeyResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
statusRecorder := s.statusRecorder
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil, errors.New("client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
}
|
||||
|
||||
peerAddress := req.GetPeerAddress()
|
||||
hostKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||
|
||||
response := &proto.GetPeerSSHHostKeyResponse{
|
||||
Found: found,
|
||||
}
|
||||
|
||||
if !found {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
response.SshHostKey = hostKey
|
||||
|
||||
if statusRecorder == nil {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||
response.PeerIP = peerState.IP
|
||||
response.PeerFQDN = peerState.FQDN
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled)
|
||||
func (s *Server) getJWTCacheTTL() time.Duration {
|
||||
s.mutex.Lock()
|
||||
config := s.config
|
||||
s.mutex.Unlock()
|
||||
|
||||
if config == nil || config.SSHJWTCacheTTL == nil {
|
||||
return defaultJWTCacheTTL
|
||||
}
|
||||
|
||||
seconds := *config.SSHJWTCacheTTL
|
||||
if seconds == 0 {
|
||||
log.Debug("SSH JWT cache disabled (configured to 0)")
|
||||
return 0
|
||||
}
|
||||
|
||||
ttl := time.Duration(seconds) * time.Second
|
||||
log.Debugf("SSH JWT cache TTL set to %v from config", ttl)
|
||||
return ttl
|
||||
}
|
||||
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
func (s *Server) RequestJWTAuth(
|
||||
ctx context.Context,
|
||||
msg *proto.RequestJWTAuthRequest,
|
||||
) (*proto.RequestJWTAuthResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
config := s.config
|
||||
s.mutex.Unlock()
|
||||
|
||||
if config == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
|
||||
}
|
||||
|
||||
jwtCacheTTL := s.getJWTCacheTTL()
|
||||
if jwtCacheTTL > 0 {
|
||||
if cachedToken, found := s.jwtCache.get(); found {
|
||||
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
|
||||
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
CachedToken: cachedToken,
|
||||
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
hint := ""
|
||||
if msg.Hint != nil {
|
||||
hint = *msg.Hint
|
||||
}
|
||||
|
||||
if hint == "" {
|
||||
hint = profilemanager.GetLoginHint()
|
||||
}
|
||||
|
||||
isDesktop := isUnixRunningDesktop()
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow.flow = oAuthFlow
|
||||
s.oauthAuthFlow.info = authInfo
|
||||
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
|
||||
s.mutex.Unlock()
|
||||
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
VerificationURI: authInfo.VerificationURI,
|
||||
VerificationURIComplete: authInfo.VerificationURIComplete,
|
||||
UserCode: authInfo.UserCode,
|
||||
DeviceCode: authInfo.DeviceCode,
|
||||
ExpiresIn: int64(authInfo.ExpiresIn),
|
||||
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
func (s *Server) WaitJWTToken(
|
||||
ctx context.Context,
|
||||
req *proto.WaitJWTTokenRequest,
|
||||
) (*proto.WaitJWTTokenResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
oAuthFlow := s.oauthAuthFlow.flow
|
||||
authInfo := s.oauthAuthFlow.info
|
||||
s.mutex.Unlock()
|
||||
|
||||
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
|
||||
}
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
|
||||
}
|
||||
|
||||
token := tokenInfo.GetTokenToUse()
|
||||
|
||||
jwtCacheTTL := s.getJWTCacheTTL()
|
||||
if jwtCacheTTL > 0 {
|
||||
s.jwtCache.store(token, jwtCacheTTL)
|
||||
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
|
||||
} else {
|
||||
log.Debug("JWT caching disabled, not storing token")
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow = oauthAuthFlow{}
|
||||
s.mutex.Unlock()
|
||||
return &proto.WaitJWTTokenResponse{
|
||||
Token: tokenInfo.GetTokenToUse(),
|
||||
TokenType: tokenInfo.TokenType,
|
||||
ExpiresIn: int64(tokenInfo.ExpiresIn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
}
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if s.connectClient == nil {
|
||||
return
|
||||
@@ -1132,25 +1373,61 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
disableServerRoutes := cfg.DisableServerRoutes
|
||||
blockLANAccess := cfg.BlockLANAccess
|
||||
|
||||
enableSSHRoot := false
|
||||
if cfg.EnableSSHRoot != nil {
|
||||
enableSSHRoot = *cfg.EnableSSHRoot
|
||||
}
|
||||
|
||||
enableSSHSFTP := false
|
||||
if cfg.EnableSSHSFTP != nil {
|
||||
enableSSHSFTP = *cfg.EnableSSHSFTP
|
||||
}
|
||||
|
||||
enableSSHLocalPortForwarding := false
|
||||
if cfg.EnableSSHLocalPortForwarding != nil {
|
||||
enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding
|
||||
}
|
||||
|
||||
enableSSHRemotePortForwarding := false
|
||||
if cfg.EnableSSHRemotePortForwarding != nil {
|
||||
enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding
|
||||
}
|
||||
|
||||
disableSSHAuth := false
|
||||
if cfg.DisableSSHAuth != nil {
|
||||
disableSSHAuth = *cfg.DisableSSHAuth
|
||||
}
|
||||
|
||||
sshJWTCacheTTL := int32(0)
|
||||
if cfg.SSHJWTCacheTTL != nil {
|
||||
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
|
||||
}
|
||||
|
||||
return &proto.GetConfigResponse{
|
||||
ManagementUrl: managementURL.String(),
|
||||
PreSharedKey: preSharedKey,
|
||||
AdminURL: adminURL.String(),
|
||||
InterfaceName: cfg.WgIface,
|
||||
WireguardPort: int64(cfg.WgPort),
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
BlockInbound: cfg.BlockInbound,
|
||||
DisableNotifications: disableNotifications,
|
||||
NetworkMonitor: networkMonitor,
|
||||
DisableDns: disableDNS,
|
||||
DisableClientRoutes: disableClientRoutes,
|
||||
DisableServerRoutes: disableServerRoutes,
|
||||
BlockLanAccess: blockLANAccess,
|
||||
ManagementUrl: managementURL.String(),
|
||||
PreSharedKey: preSharedKey,
|
||||
AdminURL: adminURL.String(),
|
||||
InterfaceName: cfg.WgIface,
|
||||
WireguardPort: int64(cfg.WgPort),
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
BlockInbound: cfg.BlockInbound,
|
||||
DisableNotifications: disableNotifications,
|
||||
NetworkMonitor: networkMonitor,
|
||||
DisableDns: disableDNS,
|
||||
DisableClientRoutes: disableClientRoutes,
|
||||
DisableServerRoutes: disableServerRoutes,
|
||||
BlockLanAccess: blockLANAccess,
|
||||
EnableSSHRoot: enableSSHRoot,
|
||||
EnableSSHSFTP: enableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1258,7 +1535,7 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
|
||||
log.Tracef("running client connection")
|
||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
|
||||
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||
if err := s.connectClient.Run(runningChan); err != nil {
|
||||
if err := s.connectClient.Run(runningChan, s.logFile); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -1332,93 +1609,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
|
||||
return defaultDuration
|
||||
}
|
||||
|
||||
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
pbFullStatus := proto.FullStatus{
|
||||
ManagementState: &proto.ManagementState{},
|
||||
SignalState: &proto.SignalState{},
|
||||
LocalPeerState: &proto.LocalPeerState{},
|
||||
Peers: []*proto.PeerState{},
|
||||
}
|
||||
|
||||
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
|
||||
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
|
||||
if err := fullStatus.ManagementState.Error; err != nil {
|
||||
pbFullStatus.ManagementState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
|
||||
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
|
||||
if err := fullStatus.SignalState.Error; err != nil {
|
||||
pbFullStatus.SignalState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
|
||||
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
|
||||
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
|
||||
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
|
||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
|
||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
||||
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
||||
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
|
||||
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
pbPeerState := &proto.PeerState{
|
||||
IP: peerState.IP,
|
||||
PubKey: peerState.PubKey,
|
||||
ConnStatus: peerState.ConnStatus.String(),
|
||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||
Relayed: peerState.Relayed,
|
||||
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
||||
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
||||
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
||||
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
||||
RelayAddress: peerState.RelayServerAddress,
|
||||
Fqdn: peerState.FQDN,
|
||||
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
||||
BytesRx: peerState.BytesRx,
|
||||
BytesTx: peerState.BytesTx,
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: maps.Keys(peerState.GetRoutes()),
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
for _, relayState := range fullStatus.Relays {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
}
|
||||
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
||||
}
|
||||
|
||||
for _, dnsState := range fullStatus.NSGroupStates {
|
||||
var err string
|
||||
if dnsState.Error != nil {
|
||||
err = dnsState.Error.Error()
|
||||
}
|
||||
|
||||
var servers []string
|
||||
for _, server := range dnsState.Servers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
|
||||
pbDnsState := &proto.NSGroupState{
|
||||
Servers: servers,
|
||||
Domains: dnsState.Domains,
|
||||
Enabled: dnsState.Enabled,
|
||||
Error: err,
|
||||
}
|
||||
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
||||
}
|
||||
|
||||
return &pbFullStatus
|
||||
}
|
||||
|
||||
// sendTerminalNotification sends a terminal notification message
|
||||
// to inform the user that the NetBird connection session has expired.
|
||||
func sendTerminalNotification() error {
|
||||
|
||||
@@ -15,6 +15,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
@@ -290,7 +295,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
jobManager := job.NewJobManager(nil, store)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -311,13 +316,16 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -72,6 +72,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
lazyConnectionEnabled := true
|
||||
blockInbound := true
|
||||
mtu := int64(1280)
|
||||
sshJWTCacheTTL := int32(300)
|
||||
|
||||
req := &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
@@ -102,6 +103,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
CleanDNSLabels: false,
|
||||
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
||||
Mtu: &mtu,
|
||||
SshJWTCacheTTL: &sshJWTCacheTTL,
|
||||
}
|
||||
|
||||
_, err = s.SetConfig(ctx, req)
|
||||
@@ -146,6 +148,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
|
||||
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
|
||||
require.Equal(t, uint16(mtu), cfg.MTU)
|
||||
require.NotNil(t, cfg.SSHJWTCacheTTL)
|
||||
require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL)
|
||||
|
||||
verifyAllFieldsCovered(t, req)
|
||||
}
|
||||
@@ -167,30 +171,36 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
}
|
||||
|
||||
expectedFields := map[string]bool{
|
||||
"ManagementUrl": true,
|
||||
"AdminURL": true,
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
"DisableAutoConnect": true,
|
||||
"NetworkMonitor": true,
|
||||
"DisableClientRoutes": true,
|
||||
"DisableServerRoutes": true,
|
||||
"DisableDns": true,
|
||||
"DisableFirewall": true,
|
||||
"BlockLanAccess": true,
|
||||
"DisableNotifications": true,
|
||||
"LazyConnectionEnabled": true,
|
||||
"BlockInbound": true,
|
||||
"NatExternalIPs": true,
|
||||
"CustomDNSAddress": true,
|
||||
"ExtraIFaceBlacklist": true,
|
||||
"DnsLabels": true,
|
||||
"DnsRouteInterval": true,
|
||||
"Mtu": true,
|
||||
"ManagementUrl": true,
|
||||
"AdminURL": true,
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
"DisableAutoConnect": true,
|
||||
"NetworkMonitor": true,
|
||||
"DisableClientRoutes": true,
|
||||
"DisableServerRoutes": true,
|
||||
"DisableDns": true,
|
||||
"DisableFirewall": true,
|
||||
"BlockLanAccess": true,
|
||||
"DisableNotifications": true,
|
||||
"LazyConnectionEnabled": true,
|
||||
"BlockInbound": true,
|
||||
"NatExternalIPs": true,
|
||||
"CustomDNSAddress": true,
|
||||
"ExtraIFaceBlacklist": true,
|
||||
"DnsLabels": true,
|
||||
"DnsRouteInterval": true,
|
||||
"Mtu": true,
|
||||
"EnableSSHRoot": true,
|
||||
"EnableSSHSFTP": true,
|
||||
"EnableSSHLocalPortForwarding": true,
|
||||
"EnableSSHRemotePortForwarding": true,
|
||||
"DisableSSHAuth": true,
|
||||
"SshJWTCacheTTL": true,
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(req).Elem()
|
||||
@@ -221,29 +231,35 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
// Map of CLI flag names to their corresponding SetConfigRequest field names.
|
||||
// This map must be updated when adding new config-related CLI flags.
|
||||
flagToField := map[string]string{
|
||||
"management-url": "ManagementUrl",
|
||||
"admin-url": "AdminURL",
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
"disable-auto-connect": "DisableAutoConnect",
|
||||
"network-monitor": "NetworkMonitor",
|
||||
"disable-client-routes": "DisableClientRoutes",
|
||||
"disable-server-routes": "DisableServerRoutes",
|
||||
"disable-dns": "DisableDns",
|
||||
"disable-firewall": "DisableFirewall",
|
||||
"block-lan-access": "BlockLanAccess",
|
||||
"block-inbound": "BlockInbound",
|
||||
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||
"external-ip-map": "NatExternalIPs",
|
||||
"dns-resolver-address": "CustomDNSAddress",
|
||||
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||
"extra-dns-labels": "DnsLabels",
|
||||
"dns-router-interval": "DnsRouteInterval",
|
||||
"mtu": "Mtu",
|
||||
"management-url": "ManagementUrl",
|
||||
"admin-url": "AdminURL",
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
"disable-auto-connect": "DisableAutoConnect",
|
||||
"network-monitor": "NetworkMonitor",
|
||||
"disable-client-routes": "DisableClientRoutes",
|
||||
"disable-server-routes": "DisableServerRoutes",
|
||||
"disable-dns": "DisableDns",
|
||||
"disable-firewall": "DisableFirewall",
|
||||
"block-lan-access": "BlockLanAccess",
|
||||
"block-inbound": "BlockInbound",
|
||||
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||
"external-ip-map": "NatExternalIPs",
|
||||
"dns-resolver-address": "CustomDNSAddress",
|
||||
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||
"extra-dns-labels": "DnsLabels",
|
||||
"dns-router-interval": "DnsRouteInterval",
|
||||
"mtu": "Mtu",
|
||||
"enable-ssh-root": "EnableSSHRoot",
|
||||
"enable-ssh-sftp": "EnableSSHSFTP",
|
||||
"enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding",
|
||||
"enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding",
|
||||
"disable-ssh-auth": "DisableSSHAuth",
|
||||
"ssh-jwt-cache-ttl": "SshJWTCacheTTL",
|
||||
}
|
||||
|
||||
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
mgr.RegisterState(&config.ShutdownState{})
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
@@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
mgr.RegisterState(&nftables.ShutdownState{})
|
||||
mgr.RegisterState(&iptables.ShutdownState{})
|
||||
mgr.RegisterState(&config.ShutdownState{})
|
||||
}
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Client wraps crypto/ssh Client to simplify usage
|
||||
type Client struct {
|
||||
client *ssh.Client
|
||||
}
|
||||
|
||||
// Close closes the wrapped SSH Client
|
||||
func (c *Client) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
// OpenTerminal starts an interactive terminal session with the remote SSH server
|
||||
func (c *Client) OpenTerminal() error {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open new session: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
fd := int(os.Stdout.Fd())
|
||||
state, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to run raw terminal: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := term.Restore(fd, state)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
w, h, err := term.GetSize(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("terminal get size: %s", err)
|
||||
}
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
}
|
||||
|
||||
terminal := os.Getenv("TERM")
|
||||
if terminal == "" {
|
||||
terminal = "xterm-256color"
|
||||
}
|
||||
if err := session.RequestPty(terminal, h, w, modes); err != nil {
|
||||
return fmt.Errorf("failed requesting pty session with xterm: %s", err)
|
||||
}
|
||||
|
||||
session.Stdout = os.Stdout
|
||||
session.Stderr = os.Stderr
|
||||
session.Stdin = os.Stdin
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
return fmt.Errorf("failed to start login shell on the remote host: %s", err)
|
||||
}
|
||||
|
||||
if err := session.Wait(); err != nil {
|
||||
if e, ok := err.(*ssh.ExitError); ok {
|
||||
if e.ExitStatus() == 130 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed running SSH session: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DialWithKey connects to the remote SSH server with a provided private key file (PEM).
|
||||
func DialWithKey(addr, user string, privateKey []byte) (*Client, error) {
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Timeout: 5 * time.Second,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
|
||||
}
|
||||
|
||||
return Dial("tcp", addr, config)
|
||||
}
|
||||
|
||||
// Dial connects to the remote SSH server.
|
||||
func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
||||
client, err := ssh.Dial(network, addr, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Client{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
699
client/ssh/client/client.go
Normal file
699
client/ssh/client/client.go
Normal file
@@ -0,0 +1,699 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/knownhosts"
|
||||
"golang.org/x/term"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDaemonAddr is the default address for the NetBird daemon
|
||||
DefaultDaemonAddr = "unix:///var/run/netbird.sock"
|
||||
// DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows
|
||||
DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731"
|
||||
)
|
||||
|
||||
// Client wraps crypto/ssh Client for simplified SSH operations
|
||||
type Client struct {
|
||||
client *ssh.Client
|
||||
terminalState *term.State
|
||||
terminalFd int
|
||||
|
||||
windowsStdoutMode uint32 // nolint:unused
|
||||
windowsStdinMode uint32 // nolint:unused
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
func (c *Client) OpenTerminal(ctx context.Context) error {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil {
|
||||
log.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := c.setupTerminalMode(ctx, session); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
return c.waitForSession(ctx, session)
|
||||
}
|
||||
|
||||
// setupSessionIO connects session streams to local terminal
|
||||
func (c *Client) setupSessionIO(session *ssh.Session) {
|
||||
session.Stdout = os.Stdout
|
||||
session.Stderr = os.Stderr
|
||||
session.Stdin = os.Stdin
|
||||
}
|
||||
|
||||
// waitForSession waits for the session to complete with context cancellation
|
||||
func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
defer c.restoreTerminal()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
return c.handleSessionError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSessionError processes session termination errors
|
||||
func (c *Client) handleSessionError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if !errors.As(err, &e) && !errors.As(err, &em) {
|
||||
return fmt.Errorf("session wait: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreTerminal restores the terminal to its original state
|
||||
func (c *Client) restoreTerminal() {
|
||||
if c.terminalState != nil {
|
||||
_ = term.Restore(c.terminalFd, c.terminalState)
|
||||
c.terminalState = nil
|
||||
c.terminalFd = 0
|
||||
}
|
||||
|
||||
if err := c.restoreWindowsConsoleState(); err != nil {
|
||||
log.Debugf("restore Windows console state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteCommand executes a command on the remote host and returns the output
|
||||
func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
output, err := session.CombinedOutput(command)
|
||||
if err != nil {
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if !errors.As(err, &e) && !errors.As(err, &em) {
|
||||
return output, fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal
|
||||
func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Start(command); err != nil {
|
||||
return fmt.Errorf("start command: %w", err)
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
select {
|
||||
case <-done:
|
||||
return ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return ctx.Err()
|
||||
}
|
||||
case err := <-done:
|
||||
return c.handleCommandError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions
|
||||
func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if err := c.setupTerminalMode(ctx, session); err != nil {
|
||||
return fmt.Errorf("setup terminal mode: %w", err)
|
||||
}
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Start(command); err != nil {
|
||||
return fmt.Errorf("start command: %w", err)
|
||||
}
|
||||
|
||||
defer c.restoreTerminal()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
select {
|
||||
case <-done:
|
||||
return ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return ctx.Err()
|
||||
}
|
||||
case err := <-done:
|
||||
return c.handleCommandError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCommandError processes command execution errors
|
||||
func (c *Client) handleCommandError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if errors.As(err, &e) || errors.As(err, &em) {
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
|
||||
// setupContextCancellation sets up context cancellation for a session
|
||||
func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
_ = session.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
return func() { close(done) }
|
||||
}
|
||||
|
||||
// createSession creates a new SSH session with context cancellation setup
|
||||
func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
|
||||
cancel := c.setupContextCancellation(ctx, session)
|
||||
cleanup := func() {
|
||||
cancel()
|
||||
_ = session.Close()
|
||||
}
|
||||
|
||||
return session, cleanup, nil
|
||||
}
|
||||
|
||||
// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
|
||||
func getDefaultDaemonAddr() string {
|
||||
if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
|
||||
return addr
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
return DefaultDaemonAddrWindows
|
||||
}
|
||||
return DefaultDaemonAddr
|
||||
}
|
||||
|
||||
// DialOptions contains options for SSH connections
|
||||
type DialOptions struct {
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
DaemonAddr string
|
||||
SkipCachedToken bool
|
||||
InsecureSkipVerify bool
|
||||
}
|
||||
|
||||
// Dial connects to the given ssh server with specified options
|
||||
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
||||
daemonAddr := opts.DaemonAddr
|
||||
if daemonAddr == "" {
|
||||
daemonAddr = getDefaultDaemonAddr()
|
||||
}
|
||||
opts.DaemonAddr = daemonAddr
|
||||
|
||||
hostKeyCallback, err := createHostKeyCallback(opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create host key callback: %w", err)
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Timeout: 30 * time.Second,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
}
|
||||
|
||||
if opts.IdentityFile != "" {
|
||||
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create SSH key auth: %w", err)
|
||||
}
|
||||
config.Auth = append(config.Auth, authMethod)
|
||||
}
|
||||
|
||||
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
|
||||
}
|
||||
|
||||
// dialSSH establishes an SSH connection without JWT authentication
|
||||
func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Debugf("connection close after handshake failure: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("ssh handshake: %w", err)
|
||||
}
|
||||
|
||||
client := ssh.NewClient(clientConn, chans, reqs)
|
||||
return &Client{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
|
||||
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse address %s: %w", addr, err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSH server detection failed: %w", err)
|
||||
}
|
||||
|
||||
if !serverType.RequiresJWT() {
|
||||
return dialSSH(ctx, network, addr, config)
|
||||
}
|
||||
|
||||
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request JWT token: %w", err)
|
||||
}
|
||||
|
||||
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
|
||||
return dialSSH(ctx, network, addr, configWithJWT)
|
||||
}
|
||||
|
||||
// requestJWTToken requests a JWT token from the NetBird daemon
|
||||
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
|
||||
hint := profilemanager.GetLoginHint()
|
||||
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint)
|
||||
}
|
||||
|
||||
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
||||
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("daemon connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
verifier := nbssh.NewDaemonHostKeyVerifier(client)
|
||||
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||
return callback(hostname, remote, key)
|
||||
}
|
||||
|
||||
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
|
||||
addr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
|
||||
conn, err := grpc.NewClient(
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
|
||||
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// getKnownHostsFiles returns paths to known_hosts files in order of preference
|
||||
func getKnownHostsFiles() []string {
|
||||
var files []string
|
||||
|
||||
// User's known_hosts file (highest priority)
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
userKnownHosts := filepath.Join(homeDir, ".ssh", "known_hosts")
|
||||
files = append(files, userKnownHosts)
|
||||
}
|
||||
|
||||
// NetBird managed known_hosts files
|
||||
if runtime.GOOS == "windows" {
|
||||
programData := os.Getenv("PROGRAMDATA")
|
||||
if programData == "" {
|
||||
programData = `C:\ProgramData`
|
||||
}
|
||||
netbirdKnownHosts := filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird")
|
||||
files = append(files, netbirdKnownHosts)
|
||||
} else {
|
||||
files = append(files, "/etc/ssh/ssh_known_hosts.d/99-netbird")
|
||||
files = append(files, "/etc/ssh/ssh_known_hosts")
|
||||
}
|
||||
|
||||
return files
|
||||
}
|
||||
|
||||
// createHostKeyCallback creates a host key verification callback
|
||||
func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
|
||||
if opts.InsecureSkipVerify {
|
||||
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
|
||||
}
|
||||
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
|
||||
return nil
|
||||
}
|
||||
return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
if daemonAddr == "" {
|
||||
return fmt.Errorf("no daemon address")
|
||||
}
|
||||
return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr)
|
||||
}
|
||||
|
||||
func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error {
|
||||
knownHostsFiles := getKnownHostsFilesList(knownHostsFile)
|
||||
hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles)
|
||||
|
||||
for _, callback := range hostKeyCallbacks {
|
||||
if err := callback(hostname, remote, key); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname)
|
||||
}
|
||||
|
||||
func getKnownHostsFilesList(knownHostsFile string) []string {
|
||||
if knownHostsFile != "" {
|
||||
return []string{knownHostsFile}
|
||||
}
|
||||
return getKnownHostsFiles()
|
||||
}
|
||||
|
||||
func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback {
|
||||
var hostKeyCallbacks []ssh.HostKeyCallback
|
||||
for _, file := range knownHostsFiles {
|
||||
if callback, err := knownhosts.New(file); err == nil {
|
||||
hostKeyCallbacks = append(hostKeyCallbacks, callback)
|
||||
}
|
||||
}
|
||||
return hostKeyCallbacks
|
||||
}
|
||||
|
||||
// createSSHKeyAuth creates SSH key authentication from a private key file
|
||||
func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) {
|
||||
keyData, err := os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err)
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(keyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse SSH private key: %w", err)
|
||||
}
|
||||
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
||||
|
||||
// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr
|
||||
func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error {
|
||||
localListener, err := net.Listen("tcp", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", localAddr, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("local listener close error: %v", err)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
localConn, err := localListener.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleLocalForward(localConn, remoteAddr)
|
||||
}
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("local listener close error: %v", err)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// handleLocalForward handles a single local port forwarding connection
|
||||
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
channel, err := c.client.Dial("tcp", remoteAddr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "administratively prohibited") {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
|
||||
} else {
|
||||
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := channel.Close(); err != nil {
|
||||
log.Debugf("remote channel close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(channel, localConn); err != nil {
|
||||
log.Debugf("local forward copy error (local->remote): %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(localConn, channel); err != nil {
|
||||
log.Debugf("local forward copy error (remote->local): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
||||
func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error {
|
||||
host, port, err := c.parseRemoteAddress(remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse remote address: %w", err)
|
||||
}
|
||||
|
||||
req := c.buildTCPIPForwardRequest(host, port)
|
||||
if err := c.sendTCPIPForwardRequest(req); err != nil {
|
||||
return fmt.Errorf("setup remote forward: %w", err)
|
||||
}
|
||||
|
||||
go c.handleRemoteForwardChannels(ctx, localAddr)
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
if err := c.cancelTCPIPForwardRequest(req); err != nil {
|
||||
return fmt.Errorf("cancel tcpip-forward: %w", err)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// parseRemoteAddress parses host and port from remote address string
|
||||
func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) {
|
||||
host, portStr, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
return host, uint32(port), nil
|
||||
}
|
||||
|
||||
// buildTCPIPForwardRequest creates a tcpip-forward request message
|
||||
func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg {
|
||||
return tcpipForwardMsg{
|
||||
Host: host,
|
||||
Port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding
|
||||
func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
|
||||
ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req))
|
||||
if err != nil {
|
||||
return fmt.Errorf("send tcpip-forward request: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cancelTCPIPForwardRequest cancels the tcpip-forward request
|
||||
func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error {
|
||||
_, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req))
|
||||
if err != nil {
|
||||
return fmt.Errorf("send cancel-tcpip-forward request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannels handles incoming forwarded-tcpip channels
|
||||
func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) {
|
||||
// Get the channel once - subsequent calls return nil!
|
||||
channelRequests := c.client.HandleChannelOpen("forwarded-tcpip")
|
||||
if channelRequests == nil {
|
||||
log.Debugf("forwarded-tcpip channel type already being handled")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case newChan := <-channelRequests:
|
||||
if newChan != nil {
|
||||
go c.handleRemoteForwardChannel(newChan, localAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
|
||||
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
|
||||
channel, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := channel.Close(); err != nil {
|
||||
log.Debugf("remote channel close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
localConn, err := net.Dial("tcp", localAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(localConn, channel); err != nil {
|
||||
log.Debugf("remote forward copy error (remote->local): %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(channel, localConn); err != nil {
|
||||
log.Debugf("remote forward copy error (local->remote): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward requests
|
||||
type tcpipForwardMsg struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
||||
512
client/ssh/client/client_test.go
Normal file
512
client/ssh/client/client_test.go
Normal file
@@ -0,0 +1,512 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestSSHClient_DialWithKey(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create and start server
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Test Dial
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Verify client is connected
|
||||
assert.NotNil(t, client.client)
|
||||
}
|
||||
|
||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(output), "hello")
|
||||
})
|
||||
|
||||
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
||||
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("commands with flags work", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
||||
})
|
||||
|
||||
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
||||
var testCmd string
|
||||
if runtime.GOOS == "windows" {
|
||||
testCmd = "echo hello | Select-String notfound"
|
||||
} else {
|
||||
testCmd = "echo 'hello' | grep 'notfound'"
|
||||
}
|
||||
_, err := client.ExecuteCommand(ctx, testCmd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Generate client key for multiple connections
|
||||
|
||||
const numClients = 3
|
||||
clients := make([]*Client, numClients)
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
for i := 0; i < numClients; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
cancel()
|
||||
require.NoError(t, err, "Client %d should connect successfully", i)
|
||||
clients[i] = client
|
||||
}
|
||||
|
||||
for i, client := range clients {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "Client %d should close without error", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("connection with short timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Check for actual timeout-related errors rather than string matching
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "timeout"),
|
||||
"Expected timeout-related error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("command execution cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cmdCancel()
|
||||
|
||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||
if err != nil {
|
||||
var exitMissingErr *cryptossh.ExitMissingError
|
||||
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.As(err, &exitMissingErr)
|
||||
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_NoAuthMode(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
if client != nil {
|
||||
require.NoError(t, client.Close(), "Client should close without error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_TerminalState(t *testing.T) {
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
assert.Nil(t, client.terminalState)
|
||||
assert.Equal(t, 0, client.terminalFd)
|
||||
|
||||
client.restoreTerminal()
|
||||
assert.Nil(t, client.terminalState)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := client.OpenTerminal(ctx)
|
||||
// In test environment without a real terminal, this may complete quickly or timeout
|
||||
// Both behaviors are acceptable for testing terminal state management
|
||||
if err != nil {
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "context deadline exceeded") ||
|
||||
strings.Contains(err.Error(), "console"),
|
||||
"Should timeout or have console error on Windows")
|
||||
} else {
|
||||
// On Unix systems in test environment, we may get various errors
|
||||
// including timeouts or terminal-related errors
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "context deadline exceeded") ||
|
||||
strings.Contains(err.Error(), "terminal") ||
|
||||
strings.Contains(err.Error(), "pty"),
|
||||
"Expected timeout or terminal-related error, got: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Client) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return server, serverAddr, client
|
||||
}
|
||||
|
||||
func TestSSHClient_PortForwarding(t *testing.T) {
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("local forwarding times out gracefully", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := client.LocalPortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "connection"),
|
||||
"Expected context or connection error")
|
||||
})
|
||||
|
||||
t.Run("remote forwarding denied", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := client.RemotePortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "denied") ||
|
||||
strings.Contains(err.Error(), "disabled"),
|
||||
"Should be denied by default")
|
||||
})
|
||||
|
||||
t.Run("invalid addresses fail", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := client.LocalPortForward(ctx, "invalid:address", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
|
||||
err = client.LocalPortForward(ctx, "127.0.0.1:0", "invalid:address")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping data transfer test in short mode")
|
||||
}
|
||||
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowLocalPortForwarding(true)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Port forwarding requires the actual current user, not test user
|
||||
realUser, err := getRealCurrentUser()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Skip if running as system account that can't do port forwarding
|
||||
if testutil.IsSystemAccount(realUser) {
|
||||
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
|
||||
}
|
||||
|
||||
client, err := Dial(ctx, serverAddr, realUser, DialOptions{
|
||||
InsecureSkipVerify: true, // Skip host key verification for test
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := testServer.Close(); err != nil {
|
||||
t.Logf("test server close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
testServerAddr := testServer.Addr().String()
|
||||
expectedResponse := "Hello, World!"
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := testServer.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
buf := make([]byte, 1024)
|
||||
if _, err := c.Read(buf); err != nil {
|
||||
t.Logf("connection read error: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := c.Write([]byte(expectedResponse)); err != nil {
|
||||
t.Logf("connection write error: %v", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
localAddr := localListener.Addr().String()
|
||||
if err := localListener.Close(); err != nil {
|
||||
t.Logf("local listener close error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
err := client.LocalPortForward(ctx, localAddr, testServerAddr)
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
if isWindowsPrivilegeError(err) {
|
||||
t.Logf("Port forward failed due to Windows privilege restrictions: %v", err)
|
||||
} else {
|
||||
t.Logf("Port forward error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", localAddr, 2*time.Second)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Logf("set read deadline error: %v", err)
|
||||
}
|
||||
response := make([]byte, len(expectedResponse))
|
||||
n, err := io.ReadFull(conn, response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(expectedResponse), n)
|
||||
assert.Equal(t, expectedResponse, string(response))
|
||||
}
|
||||
|
||||
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
|
||||
func getRealCurrentUser() (string, error) {
|
||||
if runtime.GOOS == "windows" {
|
||||
if currentUser, err := user.Current(); err == nil {
|
||||
return currentUser.Username, nil
|
||||
}
|
||||
}
|
||||
|
||||
if username := os.Getenv("USER"); username != "" {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
if currentUser, err := user.Current(); err == nil {
|
||||
return currentUser.Username, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unable to determine current user")
|
||||
}
|
||||
|
||||
// isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions
|
||||
func isWindowsPrivilegeError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := strings.ToLower(err.Error())
|
||||
return strings.Contains(errStr, "ntstatus=0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD
|
||||
strings.Contains(errStr, "0xc0000041") || // STATUS_PRIVILEGE_NOT_HELD (LsaRegisterLogonProcess)
|
||||
strings.Contains(errStr, "0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD (LsaLogonUser)
|
||||
strings.Contains(errStr, "privilege") ||
|
||||
strings.Contains(errStr, "access denied") ||
|
||||
strings.Contains(errStr, "user authentication failed")
|
||||
}
|
||||
127
client/ssh/client/terminal_unix.go
Normal file
127
client/ssh/client/terminal_unix.go
Normal file
@@ -0,0 +1,127 @@
|
||||
//go:build !windows
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
|
||||
stdinFd := int(os.Stdin.Fd())
|
||||
|
||||
if !term.IsTerminal(stdinFd) {
|
||||
return c.setupNonTerminalMode(ctx, session)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
|
||||
state, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return c.setupNonTerminalMode(ctx, session)
|
||||
}
|
||||
|
||||
if err := c.setupTerminal(session, fd); err != nil {
|
||||
if restoreErr := term.Restore(fd, state); restoreErr != nil {
|
||||
log.Debugf("restore terminal state: %v", restoreErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c.terminalState = state
|
||||
c.terminalFd = fd
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
|
||||
go func() {
|
||||
defer signal.Stop(sigChan)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := term.Restore(fd, state); err != nil {
|
||||
log.Debugf("restore terminal state: %v", err)
|
||||
}
|
||||
case sig := <-sigChan:
|
||||
if err := term.Restore(fd, state); err != nil {
|
||||
log.Debugf("restore terminal state: %v", err)
|
||||
}
|
||||
signal.Reset(sig)
|
||||
s, ok := sig.(syscall.Signal)
|
||||
if !ok {
|
||||
log.Debugf("signal %v is not a syscall.Signal: %T", sig, sig)
|
||||
return
|
||||
}
|
||||
if err := syscall.Kill(syscall.Getpid(), s); err != nil {
|
||||
log.Debugf("kill process with signal %v: %v", s, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreWindowsConsoleState is a no-op on Unix systems
|
||||
func (c *Client) restoreWindowsConsoleState() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) setupTerminal(session *ssh.Session, fd int) error {
|
||||
w, h, err := term.GetSize(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get terminal size: %w", err)
|
||||
}
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
// Ctrl+C
|
||||
ssh.VINTR: 3,
|
||||
// Ctrl+\
|
||||
ssh.VQUIT: 28,
|
||||
// Backspace
|
||||
ssh.VERASE: 127,
|
||||
// Ctrl+U
|
||||
ssh.VKILL: 21,
|
||||
// Ctrl+D
|
||||
ssh.VEOF: 4,
|
||||
ssh.VEOL: 0,
|
||||
ssh.VEOL2: 0,
|
||||
// Ctrl+Q
|
||||
ssh.VSTART: 17,
|
||||
// Ctrl+S
|
||||
ssh.VSTOP: 19,
|
||||
// Ctrl+Z
|
||||
ssh.VSUSP: 26,
|
||||
// Ctrl+O
|
||||
ssh.VDISCARD: 15,
|
||||
// Ctrl+R
|
||||
ssh.VREPRINT: 18,
|
||||
// Ctrl+W
|
||||
ssh.VWERASE: 23,
|
||||
// Ctrl+V
|
||||
ssh.VLNEXT: 22,
|
||||
}
|
||||
|
||||
terminal := os.Getenv("TERM")
|
||||
if terminal == "" {
|
||||
terminal = "xterm-256color"
|
||||
}
|
||||
|
||||
if err := session.RequestPty(terminal, h, w, modes); err != nil {
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
265
client/ssh/client/terminal_windows.go
Normal file
265
client/ssh/client/terminal_windows.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
enableProcessedInput = 0x0001
|
||||
enableLineInput = 0x0002
|
||||
enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT
|
||||
enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode)
|
||||
enableVirtualTerminalInput = 0x0200
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
|
||||
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
|
||||
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
|
||||
)
|
||||
|
||||
// ConsoleUnavailableError indicates that Windows console handles are not available
|
||||
// (e.g., in CI environments where stdout/stdin are redirected)
|
||||
type ConsoleUnavailableError struct {
|
||||
Operation string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *ConsoleUnavailableError) Error() string {
|
||||
return fmt.Sprintf("console unavailable for %s: %v", e.Operation, e.Err)
|
||||
}
|
||||
|
||||
func (e *ConsoleUnavailableError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type coord struct {
|
||||
x, y int16
|
||||
}
|
||||
|
||||
type smallRect struct {
|
||||
left, top, right, bottom int16
|
||||
}
|
||||
|
||||
type consoleScreenBufferInfo struct {
|
||||
size coord
|
||||
cursorPosition coord
|
||||
attributes uint16
|
||||
window smallRect
|
||||
maximumWindowSize coord
|
||||
}
|
||||
|
||||
func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) error {
|
||||
if err := c.saveWindowsConsoleState(); err != nil {
|
||||
var consoleErr *ConsoleUnavailableError
|
||||
if errors.As(err, &consoleErr) {
|
||||
log.Debugf("console unavailable, not requesting PTY: %v", err)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("save console state: %w", err)
|
||||
}
|
||||
|
||||
if err := c.enableWindowsVirtualTerminal(); err != nil {
|
||||
var consoleErr *ConsoleUnavailableError
|
||||
if errors.As(err, &consoleErr) {
|
||||
log.Debugf("virtual terminal unavailable: %v", err)
|
||||
} else {
|
||||
return fmt.Errorf("failed to enable virtual terminal: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
w, h := c.getWindowsConsoleSize()
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
ssh.ICRNL: 1,
|
||||
ssh.OPOST: 1,
|
||||
ssh.ONLCR: 1,
|
||||
ssh.ISIG: 1,
|
||||
ssh.ICANON: 1,
|
||||
ssh.VINTR: 3, // Ctrl+C
|
||||
ssh.VQUIT: 28, // Ctrl+\
|
||||
ssh.VERASE: 127, // Backspace
|
||||
ssh.VKILL: 21, // Ctrl+U
|
||||
ssh.VEOF: 4, // Ctrl+D
|
||||
ssh.VEOL: 0,
|
||||
ssh.VEOL2: 0,
|
||||
ssh.VSTART: 17, // Ctrl+Q
|
||||
ssh.VSTOP: 19, // Ctrl+S
|
||||
ssh.VSUSP: 26, // Ctrl+Z
|
||||
ssh.VDISCARD: 15, // Ctrl+O
|
||||
ssh.VWERASE: 23, // Ctrl+W
|
||||
ssh.VLNEXT: 22, // Ctrl+V
|
||||
ssh.VREPRINT: 18, // Ctrl+R
|
||||
}
|
||||
|
||||
if err := session.RequestPty("xterm-256color", h, w, modes); err != nil {
|
||||
if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil {
|
||||
log.Debugf("restore Windows console state: %v", restoreErr)
|
||||
}
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) saveWindowsConsoleState() error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Debugf("panic in saveWindowsConsoleState: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
|
||||
var stdoutMode, stdinMode uint32
|
||||
|
||||
ret, _, err := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&stdoutMode)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get stdout console mode: %v", err)
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdout console mode",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, err = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&stdinMode)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get stdin console mode: %v", err)
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdin console mode",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
c.terminalFd = 1
|
||||
c.windowsStdoutMode = stdoutMode
|
||||
c.windowsStdinMode = stdinMode
|
||||
|
||||
log.Debugf("saved Windows console state - stdout: 0x%04x, stdin: 0x%04x", stdoutMode, stdinMode)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) enableWindowsVirtualTerminal() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in enableWindowsVirtualTerminal: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
var mode uint32
|
||||
|
||||
ret, _, winErr := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&mode)))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdout console mode for VT",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
mode |= enableVirtualTerminalProcessing
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdout), uintptr(mode))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "enable virtual terminal processing",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, winErr = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&mode)))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdin console mode for VT",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
mode &= ^uint32(enableLineInput | enableEchoInput | enableProcessedInput)
|
||||
mode |= enableVirtualTerminalInput
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(mode))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "set stdin raw mode",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("enabled Windows virtual terminal processing")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) getWindowsConsoleSize() (int, int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Debugf("panic in getWindowsConsoleSize: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
var csbi consoleScreenBufferInfo
|
||||
|
||||
ret, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(stdout), uintptr(unsafe.Pointer(&csbi)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get console buffer info, using defaults: %v", err)
|
||||
return 80, 24
|
||||
}
|
||||
|
||||
width := int(csbi.window.right - csbi.window.left + 1)
|
||||
height := int(csbi.window.bottom - csbi.window.top + 1)
|
||||
|
||||
log.Debugf("Windows console size: %dx%d", width, height)
|
||||
return width, height
|
||||
}
|
||||
|
||||
func (c *Client) restoreWindowsConsoleState() error {
|
||||
var err error
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in restoreWindowsConsoleState: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if c.terminalFd != 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
|
||||
ret, _, winErr := procSetConsoleMode.Call(uintptr(stdout), uintptr(c.windowsStdoutMode))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to restore stdout console mode: %v", winErr)
|
||||
if err == nil {
|
||||
err = fmt.Errorf("restore stdout console mode: %w", winErr)
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(c.windowsStdinMode))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to restore stdin console mode: %v", winErr)
|
||||
if err == nil {
|
||||
err = fmt.Errorf("restore stdin console mode: %w", winErr)
|
||||
}
|
||||
}
|
||||
|
||||
c.terminalFd = 0
|
||||
c.windowsStdoutMode = 0
|
||||
c.windowsStdinMode = 0
|
||||
|
||||
log.Debugf("restored Windows console state")
|
||||
return err
|
||||
}
|
||||
171
client/ssh/common.go
Normal file
171
client/ssh/common.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
NetBirdSSHConfigFile = "99-netbird.conf"
|
||||
|
||||
UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
|
||||
WindowsSSHConfigDir = "ssh/ssh_config.d"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrPeerNotFound indicates the peer was not found in the network
|
||||
ErrPeerNotFound = errors.New("peer not found in network")
|
||||
// ErrNoStoredKey indicates the peer has no stored SSH host key
|
||||
ErrNoStoredKey = errors.New("peer has no stored SSH host key")
|
||||
)
|
||||
|
||||
// HostKeyVerifier provides SSH host key verification
|
||||
type HostKeyVerifier interface {
|
||||
VerifySSHHostKey(peerAddress string, key []byte) error
|
||||
}
|
||||
|
||||
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
|
||||
type DaemonHostKeyVerifier struct {
|
||||
client proto.DaemonServiceClient
|
||||
}
|
||||
|
||||
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
|
||||
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
|
||||
return &DaemonHostKeyVerifier{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
|
||||
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
||||
PeerAddress: peerAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !response.GetFound() {
|
||||
return ErrPeerNotFound
|
||||
}
|
||||
|
||||
storedKeyData := response.GetSshHostKey()
|
||||
|
||||
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
|
||||
}
|
||||
|
||||
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
||||
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) {
|
||||
req := &proto.RequestJWTAuthRequest{}
|
||||
if hint != "" {
|
||||
req.Hint = &hint
|
||||
}
|
||||
authResponse, err := client.RequestJWTAuth(ctx, req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request JWT auth: %w", err)
|
||||
}
|
||||
|
||||
if useCache && authResponse.CachedToken != "" {
|
||||
log.Debug("Using cached authentication token")
|
||||
return authResponse.CachedToken, nil
|
||||
}
|
||||
|
||||
if stderr != nil {
|
||||
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
||||
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
|
||||
if authResponse.UserCode != "" {
|
||||
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
||||
}
|
||||
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
||||
}
|
||||
|
||||
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
|
||||
DeviceCode: authResponse.DeviceCode,
|
||||
UserCode: authResponse.UserCode,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wait for JWT token: %w", err)
|
||||
}
|
||||
|
||||
if stdout != nil {
|
||||
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
|
||||
}
|
||||
return tokenResponse.Token, nil
|
||||
}
|
||||
|
||||
// VerifyHostKey verifies an SSH host key against stored peer key data.
|
||||
// Returns nil only if the presented key matches the stored key.
|
||||
// Returns ErrNoStoredKey if storedKeyData is empty.
|
||||
// Returns an error if the keys don't match or if parsing fails.
|
||||
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
|
||||
if len(storedKeyData) == 0 {
|
||||
return ErrNoStoredKey
|
||||
}
|
||||
|
||||
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
|
||||
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddJWTAuth prepends JWT password authentication to existing auth methods.
|
||||
// This ensures JWT auth is tried first while preserving any existing auth methods.
|
||||
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
|
||||
configWithJWT := *config
|
||||
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
|
||||
return &configWithJWT
|
||||
}
|
||||
|
||||
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
|
||||
// It tries multiple addresses (hostname, IP) for the peer before failing.
|
||||
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
addresses := buildAddressList(hostname, remote)
|
||||
presentedKey := key.Marshal()
|
||||
|
||||
for _, addr := range addresses {
|
||||
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
|
||||
if errors.Is(err, ErrPeerNotFound) {
|
||||
// Try other addresses for this peer
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
// Verified
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// buildAddressList creates a list of addresses to check for host key verification.
|
||||
// It includes the original hostname and extracts the host part from the remote address if different.
|
||||
func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
addresses := []string{hostname}
|
||||
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
|
||||
if host != hostname {
|
||||
addresses = append(addresses, host)
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
282
client/ssh/config/manager.go
Normal file
282
client/ssh/config/manager.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
|
||||
|
||||
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
|
||||
|
||||
MaxPeersForSSHConfig = 200
|
||||
|
||||
fileWriteTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
func isSSHConfigDisabled() bool {
|
||||
value := os.Getenv(EnvDisableSSHConfig)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
disabled, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
func isSSHConfigForced() bool {
|
||||
value := os.Getenv(EnvForceSSHConfig)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
forced, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return forced
|
||||
}
|
||||
|
||||
// shouldGenerateSSHConfig checks if SSH config should be generated based on peer count
|
||||
func shouldGenerateSSHConfig(peerCount int) bool {
|
||||
if isSSHConfigDisabled() {
|
||||
return false
|
||||
}
|
||||
|
||||
if isSSHConfigForced() {
|
||||
return true
|
||||
}
|
||||
|
||||
return peerCount <= MaxPeersForSSHConfig
|
||||
}
|
||||
|
||||
// writeFileWithTimeout writes data to a file with a timeout
|
||||
func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- os.WriteFile(filename, data, perm)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
|
||||
}
|
||||
}
|
||||
|
||||
// Manager handles SSH client configuration for NetBird peers
|
||||
type Manager struct {
|
||||
sshConfigDir string
|
||||
sshConfigFile string
|
||||
}
|
||||
|
||||
// PeerSSHInfo represents a peer's SSH configuration information
|
||||
type PeerSSHInfo struct {
|
||||
Hostname string
|
||||
IP string
|
||||
FQDN string
|
||||
}
|
||||
|
||||
// New creates a new SSH config manager
|
||||
func New() *Manager {
|
||||
sshConfigDir := getSystemSSHConfigDir()
|
||||
return &Manager{
|
||||
sshConfigDir: sshConfigDir,
|
||||
sshConfigFile: nbssh.NetBirdSSHConfigFile,
|
||||
}
|
||||
}
|
||||
|
||||
// getSystemSSHConfigDir returns platform-specific SSH configuration directory
|
||||
func getSystemSSHConfigDir() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return getWindowsSSHConfigDir()
|
||||
}
|
||||
return nbssh.UnixSSHConfigDir
|
||||
}
|
||||
|
||||
func getWindowsSSHConfigDir() string {
|
||||
programData := os.Getenv("PROGRAMDATA")
|
||||
if programData == "" {
|
||||
programData = `C:\ProgramData`
|
||||
}
|
||||
return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
|
||||
}
|
||||
|
||||
// SetupSSHClientConfig creates SSH client configuration for NetBird peers
|
||||
func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
|
||||
if !shouldGenerateSSHConfig(len(peers)) {
|
||||
m.logSkipReason(len(peers))
|
||||
return nil
|
||||
}
|
||||
|
||||
sshConfig, err := m.buildSSHConfig(peers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build SSH config: %w", err)
|
||||
}
|
||||
return m.writeSSHConfig(sshConfig)
|
||||
}
|
||||
|
||||
func (m *Manager) logSkipReason(peerCount int) {
|
||||
if isSSHConfigDisabled() {
|
||||
log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
|
||||
} else {
|
||||
log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.",
|
||||
peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
|
||||
sshConfig := m.buildConfigHeader()
|
||||
|
||||
var allHostPatterns []string
|
||||
for _, peer := range peers {
|
||||
hostPatterns := m.buildHostPatterns(peer)
|
||||
allHostPatterns = append(allHostPatterns, hostPatterns...)
|
||||
}
|
||||
|
||||
if len(allHostPatterns) > 0 {
|
||||
peerConfig, err := m.buildPeerConfig(allHostPatterns)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sshConfig += peerConfig
|
||||
}
|
||||
|
||||
return sshConfig, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildConfigHeader() string {
|
||||
return "# NetBird SSH client configuration\n" +
|
||||
"# Generated automatically - do not edit manually\n" +
|
||||
"#\n" +
|
||||
"# To disable SSH config management, use:\n" +
|
||||
"# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" +
|
||||
"#\n\n"
|
||||
}
|
||||
|
||||
func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||
uniquePatterns := make(map[string]bool)
|
||||
var deduplicatedPatterns []string
|
||||
for _, pattern := range allHostPatterns {
|
||||
if !uniquePatterns[pattern] {
|
||||
uniquePatterns[pattern] = true
|
||||
deduplicatedPatterns = append(deduplicatedPatterns, pattern)
|
||||
}
|
||||
}
|
||||
|
||||
execPath, err := m.getNetBirdExecutablePath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
||||
}
|
||||
|
||||
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
} else {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
|
||||
}
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||
config += " StrictHostKeyChecking no\n"
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += " UserKnownHostsFile NUL\n"
|
||||
} else {
|
||||
config += " UserKnownHostsFile /dev/null\n"
|
||||
}
|
||||
|
||||
config += " CheckHostIP no\n"
|
||||
config += " LogLevel ERROR\n\n"
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||
var hostPatterns []string
|
||||
if peer.IP != "" {
|
||||
hostPatterns = append(hostPatterns, peer.IP)
|
||||
}
|
||||
if peer.FQDN != "" {
|
||||
hostPatterns = append(hostPatterns, peer.FQDN)
|
||||
}
|
||||
if peer.Hostname != "" && peer.Hostname != peer.FQDN {
|
||||
hostPatterns = append(hostPatterns, peer.Hostname)
|
||||
}
|
||||
return hostPatterns
|
||||
}
|
||||
|
||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
|
||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||
}
|
||||
|
||||
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
||||
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||
}
|
||||
|
||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSSHClientConfig removes NetBird SSH configuration
|
||||
func (m *Manager) RemoveSSHClientConfig() error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
err := os.Remove(sshConfigPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
|
||||
}
|
||||
if err == nil {
|
||||
log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) getNetBirdExecutablePath() (string, error) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("retrieve executable path: %w", err)
|
||||
}
|
||||
|
||||
realPath, err := filepath.EvalSymlinks(execPath)
|
||||
if err != nil {
|
||||
log.Debugf("symlink resolution failed: %v", err)
|
||||
return execPath, nil
|
||||
}
|
||||
|
||||
return realPath, nil
|
||||
}
|
||||
|
||||
// GetSSHConfigDir returns the SSH config directory path
|
||||
func (m *Manager) GetSSHConfigDir() string {
|
||||
return m.sshConfigDir
|
||||
}
|
||||
|
||||
// GetSSHConfigFile returns the SSH config file name
|
||||
func (m *Manager) GetSSHConfigFile() string {
|
||||
return m.sshConfigFile
|
||||
}
|
||||
159
client/ssh/config/manager_test.go
Normal file
159
client/ssh/config/manager_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Test SSH config generation with peers
|
||||
peers := []PeerSSHInfo{
|
||||
{
|
||||
Hostname: "peer1",
|
||||
IP: "100.125.1.1",
|
||||
FQDN: "peer1.nb.internal",
|
||||
},
|
||||
{
|
||||
Hostname: "peer2",
|
||||
IP: "100.125.1.2",
|
||||
FQDN: "peer2.nb.internal",
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read generated config
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
content, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
configStr := string(content)
|
||||
|
||||
// Verify the basic SSH config structure exists
|
||||
assert.Contains(t, configStr, "# NetBird SSH client configuration")
|
||||
assert.Contains(t, configStr, "Generated automatically - do not edit manually")
|
||||
|
||||
// Check that peer hostnames are included
|
||||
assert.Contains(t, configStr, "100.125.1.1")
|
||||
assert.Contains(t, configStr, "100.125.1.2")
|
||||
assert.Contains(t, configStr, "peer1.nb.internal")
|
||||
assert.Contains(t, configStr, "peer2.nb.internal")
|
||||
|
||||
// Check platform-specific UserKnownHostsFile
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, configStr, "UserKnownHostsFile NUL")
|
||||
} else {
|
||||
assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSystemSSHConfigDir(t *testing.T) {
|
||||
configDir := getSystemSSHConfigDir()
|
||||
|
||||
// Path should not be empty
|
||||
assert.NotEmpty(t, configDir)
|
||||
|
||||
// Should be an absolute path
|
||||
assert.True(t, filepath.IsAbs(configDir))
|
||||
|
||||
// On Unix systems, should start with /etc
|
||||
// On Windows, should contain ProgramData
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, strings.ToLower(configDir), "programdata")
|
||||
} else {
|
||||
assert.Contains(t, configDir, "/etc/ssh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PeerLimit(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Generate many peers (more than limit)
|
||||
var peers []PeerSSHInfo
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
})
|
||||
}
|
||||
|
||||
// Test that SSH config generation is skipped when too many peers
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Config should not be created due to peer limit
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
_, err = os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
||||
}
|
||||
|
||||
func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||
// Set force environment variable
|
||||
t.Setenv(EnvForceSSHConfig, "true")
|
||||
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Generate many peers (more than limit)
|
||||
var peers []PeerSSHInfo
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
})
|
||||
}
|
||||
|
||||
// Test that SSH config generation is forced despite many peers
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Config should be created despite peer limit due to force flag
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err, "SSH config should be created when forced")
|
||||
|
||||
// Verify config contains peer hostnames
|
||||
content, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
configStr := string(content)
|
||||
assert.Contains(t, configStr, "peer0.nb.internal")
|
||||
assert.Contains(t, configStr, "peer1.nb.internal")
|
||||
}
|
||||
22
client/ssh/config/shutdown_state.go
Normal file
22
client/ssh/config/shutdown_state.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package config
|
||||
|
||||
// ShutdownState represents SSH configuration state that needs to be cleaned up.
|
||||
type ShutdownState struct {
|
||||
SSHConfigDir string
|
||||
SSHConfigFile string
|
||||
}
|
||||
|
||||
// Name returns the state name for the state manager.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "ssh_config_state"
|
||||
}
|
||||
|
||||
// Cleanup removes SSH client configuration files.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
manager := &Manager{
|
||||
sshConfigDir: s.SSHConfigDir,
|
||||
sshConfigFile: s.SSHConfigFile,
|
||||
}
|
||||
|
||||
return manager.RemoveSSHClientConfig()
|
||||
}
|
||||
99
client/ssh/detection/detection.go
Normal file
99
client/ssh/detection/detection.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package detection
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// ServerIdentifier is the base response for NetBird SSH servers
|
||||
ServerIdentifier = "NetBird-SSH-Server"
|
||||
// ProxyIdentifier is the base response for NetBird SSH proxy
|
||||
ProxyIdentifier = "NetBird-SSH-Proxy"
|
||||
// JWTRequiredMarker is appended to responses when JWT is required
|
||||
JWTRequiredMarker = "NetBird-JWT-Required"
|
||||
|
||||
// Timeout is the timeout for SSH server detection
|
||||
Timeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ServerType string
|
||||
|
||||
const (
|
||||
ServerTypeNetBirdJWT ServerType = "netbird-jwt"
|
||||
ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
|
||||
ServerTypeRegular ServerType = "regular"
|
||||
)
|
||||
|
||||
// Dialer provides network connection capabilities
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// RequiresJWT checks if the server type requires JWT authentication
|
||||
func (s ServerType) RequiresJWT() bool {
|
||||
return s == ServerTypeNetBirdJWT
|
||||
}
|
||||
|
||||
// ExitCode returns the exit code for the detect command
|
||||
func (s ServerType) ExitCode() int {
|
||||
switch s {
|
||||
case ServerTypeNetBirdJWT:
|
||||
return 0
|
||||
case ServerTypeNetBirdNoJWT:
|
||||
return 1
|
||||
case ServerTypeRegular:
|
||||
return 2
|
||||
default:
|
||||
return 2
|
||||
}
|
||||
}
|
||||
|
||||
// DetectSSHServerType detects SSH server type using the provided dialer
|
||||
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
|
||||
targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
log.Debugf("SSH connection failed for detection: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
|
||||
log.Debugf("set read deadline: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
serverBanner, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("read SSH banner: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
serverBanner = strings.TrimSpace(serverBanner)
|
||||
log.Debugf("SSH server banner: %s", serverBanner)
|
||||
|
||||
if !strings.HasPrefix(serverBanner, "SSH-") {
|
||||
log.Debugf("Invalid SSH banner")
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
if !strings.Contains(serverBanner, ServerIdentifier) {
|
||||
log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
if strings.Contains(serverBanner, JWTRequiredMarker) {
|
||||
return ServerTypeNetBirdJWT, nil
|
||||
}
|
||||
|
||||
return ServerTypeNetBirdNoJWT, nil
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func isRoot() bool {
|
||||
return os.Geteuid() == 0
|
||||
}
|
||||
|
||||
func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) {
|
||||
if !isRoot() {
|
||||
shell := getUserShell(user)
|
||||
if shell == "" {
|
||||
shell = "/bin/sh"
|
||||
}
|
||||
|
||||
return shell, []string{"-l"}, nil
|
||||
}
|
||||
|
||||
loginPath, err = exec.LookPath("login")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") {
|
||||
return loginPath, []string{"-f", user, "-p"}, nil
|
||||
}
|
||||
return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||
case "darwin":
|
||||
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil
|
||||
case "freebsd":
|
||||
return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build !darwin
|
||||
// +build !darwin
|
||||
|
||||
package ssh
|
||||
|
||||
import "os/user"
|
||||
|
||||
func userNameLookup(username string) (*user.User, error) {
|
||||
if username == "" || (username == "root" && !isRoot()) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
return user.Lookup(username)
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func userNameLookup(username string) (*user.User, error) {
|
||||
if username == "" || (username == "root" && !isRoot()) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
var userObject *user.User
|
||||
userObject, err := user.Lookup(username)
|
||||
if err != nil && err.Error() == user.UnknownUserError(username).Error() {
|
||||
return idUserNameLookup(username)
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userObject, nil
|
||||
}
|
||||
|
||||
func idUserNameLookup(username string) (*user.User, error) {
|
||||
cmd := exec.Command("id", "-P", username)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while retrieving user with id -P command, error: %v", err)
|
||||
}
|
||||
colon := ":"
|
||||
|
||||
if !bytes.Contains(out, []byte(username+colon)) {
|
||||
return nil, fmt.Errorf("unable to find user in returned string")
|
||||
}
|
||||
// netbird:********:501:20::0:0:netbird:/Users/netbird:/bin/zsh
|
||||
parts := strings.SplitN(string(out), colon, 10)
|
||||
userObject := &user.User{
|
||||
Username: parts[0],
|
||||
Uid: parts[2],
|
||||
Gid: parts[3],
|
||||
Name: parts[7],
|
||||
HomeDir: parts[8],
|
||||
}
|
||||
return userObject, nil
|
||||
}
|
||||
392
client/ssh/proxy/proxy.go
Normal file
392
client/ssh/proxy/proxy.go
Normal file
@@ -0,0 +1,392 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
// sshConnectionTimeout is the timeout for SSH TCP connection establishment
|
||||
sshConnectionTimeout = 120 * time.Second
|
||||
// sshHandshakeTimeout is the timeout for SSH handshake completion
|
||||
sshHandshakeTimeout = 30 * time.Second
|
||||
|
||||
jwtAuthErrorMsg = "JWT authentication: %w"
|
||||
)
|
||||
|
||||
type SSHProxy struct {
|
||||
daemonAddr string
|
||||
targetHost string
|
||||
targetPort int
|
||||
stderr io.Writer
|
||||
conn *grpc.ClientConn
|
||||
daemonClient proto.DaemonServiceClient
|
||||
}
|
||||
|
||||
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
|
||||
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
|
||||
return &SSHProxy{
|
||||
daemonAddr: daemonAddr,
|
||||
targetHost: targetHost,
|
||||
targetPort: targetPort,
|
||||
stderr: stderr,
|
||||
conn: grpcConn,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) Close() error {
|
||||
if p.conn != nil {
|
||||
return p.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) Connect(ctx context.Context) error {
|
||||
hint := profilemanager.GetLoginHint()
|
||||
|
||||
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
|
||||
if err != nil {
|
||||
return fmt.Errorf(jwtAuthErrorMsg, err)
|
||||
}
|
||||
|
||||
return p.runProxySSHServer(ctx, jwtToken)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
|
||||
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
|
||||
|
||||
sshServer := &ssh.Server{
|
||||
Handler: func(s ssh.Session) {
|
||||
p.handleSSHSession(ctx, s, jwtToken)
|
||||
},
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-tcpip": p.directTCPIPHandler,
|
||||
},
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": func(s ssh.Session) {
|
||||
p.sftpSubsystemHandler(s, jwtToken)
|
||||
},
|
||||
},
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": p.tcpipForwardHandler,
|
||||
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
|
||||
},
|
||||
Version: serverVersion,
|
||||
}
|
||||
|
||||
hostKey, err := generateHostKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate host key: %w", err)
|
||||
}
|
||||
sshServer.HostSigners = []ssh.Signer{hostKey}
|
||||
|
||||
conn := &stdioConn{
|
||||
stdin: os.Stdin,
|
||||
stdout: os.Stdout,
|
||||
}
|
||||
|
||||
sshServer.HandleConn(conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
|
||||
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||
|
||||
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = sshClient.Close() }()
|
||||
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = serverSession.Close() }()
|
||||
|
||||
serverSession.Stdin = session
|
||||
serverSession.Stdout = session
|
||||
serverSession.Stderr = session.Stderr()
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
if isPty {
|
||||
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
|
||||
log.Debugf("PTY request to backend: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
if err := serverSession.WindowChange(win.Height, win.Width); err != nil {
|
||||
log.Debugf("window change: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if len(session.Command()) > 0 {
|
||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||
log.Debugf("run command: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err = serverSession.Shell(); err != nil {
|
||||
log.Debugf("start shell: %v", err)
|
||||
return
|
||||
}
|
||||
if err := serverSession.Wait(); err != nil {
|
||||
log.Debugf("session wait: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
||||
var exitErr *cryptossh.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
|
||||
log.Debugf("set exit status: %v", exitErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generateHostKey() (ssh.Signer, error) {
|
||||
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate ED25519 key: %w", err)
|
||||
}
|
||||
|
||||
signer, err := cryptossh.ParsePrivateKey(keyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
type stdioConn struct {
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (c *stdioConn) Read(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return 0, io.EOF
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.stdin.Read(b)
|
||||
}
|
||||
|
||||
func (c *stdioConn) Write(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.stdout.Write(b)
|
||||
}
|
||||
|
||||
func (c *stdioConn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) LocalAddr() net.Addr {
|
||||
return &net.UnixAddr{Name: "stdio", Net: "unix"}
|
||||
}
|
||||
|
||||
func (c *stdioConn) RemoteAddr() net.Addr {
|
||||
return &net.UnixAddr{Name: "stdio", Net: "unix"}
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
|
||||
}
|
||||
|
||||
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
||||
ctx, cancel := context.WithCancel(s.Context())
|
||||
defer cancel()
|
||||
|
||||
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||
|
||||
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := sshClient.Close(); err != nil {
|
||||
log.Debugf("close SSH client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := serverSession.Close(); err != nil {
|
||||
log.Debugf("close server session: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
stdin, stdout, err := p.setupSFTPPipes(serverSession)
|
||||
if err != nil {
|
||||
log.Debugf("setup SFTP pipes: %v", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
if err := serverSession.RequestSubsystem("sftp"); err != nil {
|
||||
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
|
||||
stdin, err := serverSession.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := serverSession.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
return stdin, stdout, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
|
||||
copyErrCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(stdin, s)
|
||||
if err != nil {
|
||||
log.Debugf("SFTP client to server copy: %v", err)
|
||||
}
|
||||
if err := stdin.Close(); err != nil {
|
||||
log.Debugf("close stdin: %v", err)
|
||||
}
|
||||
copyErrCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(s, stdout)
|
||||
if err != nil {
|
||||
log.Debugf("SFTP server to client copy: %v", err)
|
||||
}
|
||||
copyErrCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if err := serverSession.Close(); err != nil {
|
||||
log.Debugf("force close server session on context cancellation: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("SFTP copy error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := serverSession.Wait(); err != nil {
|
||||
log.Debugf("SFTP session ended: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
||||
return false, []byte("port forwarding not supported in proxy")
|
||||
}
|
||||
|
||||
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
|
||||
Timeout: sshHandshakeTimeout,
|
||||
HostKeyCallback: p.verifyHostKey,
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: sshConnectionTimeout,
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to server: %w", err)
|
||||
}
|
||||
|
||||
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("SSH handshake: %w", err)
|
||||
}
|
||||
|
||||
return cryptossh.NewClient(clientConn, chans, reqs), nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
|
||||
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
|
||||
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||
return callback(hostname, remote, key)
|
||||
}
|
||||
367
client/ssh/proxy/proxy_test.go
Normal file
367
client/ssh/proxy/proxy_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" {
|
||||
if os.Args[2] == "exec" {
|
||||
if len(os.Args) > 3 {
|
||||
cmd := os.Args[3]
|
||||
if cmd == "echo" && len(os.Args) > 4 {
|
||||
fmt.Fprintln(os.Stdout, os.Args[4])
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestSSHProxy_verifyHostKey(t *testing.T) {
|
||||
t.Run("calls daemon to verify host key", func(t *testing.T) {
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
proxy := &SSHProxy{
|
||||
daemonAddr: mockDaemon.addr,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}
|
||||
|
||||
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
testPubKey, err := nbssh.GeneratePublicKey(testKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey("test-host", testPubKey)
|
||||
|
||||
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("rejects unknown host key", func(t *testing.T) {
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
proxy := &SSHProxy{
|
||||
daemonAddr: mockDaemon.addr,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}
|
||||
|
||||
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHProxy_Connect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// TODO: Windows test times out - user switching and command execution tested on Linux
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping on Windows - covered by Linux tests")
|
||||
}
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
defer func() { _ = sshServer.Stop() }()
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
defer func() {
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
}()
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
connectErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
connectErrCh <- proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err, "Should connect to proxy server")
|
||||
defer func() { _ = sshClientConn.Close() }()
|
||||
|
||||
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output("echo hello-from-proxy")
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("Command execution timed out")
|
||||
}
|
||||
|
||||
_ = session.Close()
|
||||
_ = sshClient.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
jwtToken string
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
|
||||
key, found := m.hostKeys[req.PeerAddress]
|
||||
return &proto.GetPeerSSHHostKeyResponse{
|
||||
Found: found,
|
||||
SshHostKey: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
CachedToken: m.jwtToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
|
||||
return &proto.WaitJWTTokenResponse{
|
||||
Token: m.jwtToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mockDaemon struct {
|
||||
addr string
|
||||
server *grpc.Server
|
||||
impl *mockDaemonServer
|
||||
}
|
||||
|
||||
func startMockDaemon(t *testing.T) *mockDaemon {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
impl := &mockDaemonServer{
|
||||
hostKeys: make(map[string][]byte),
|
||||
jwtToken: "test-jwt-token",
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
proto.RegisterDaemonServiceServer(grpcServer, impl)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
|
||||
return &mockDaemon{
|
||||
addr: listener.Addr().String(),
|
||||
server: grpcServer,
|
||||
impl: impl,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
|
||||
m.impl.hostKeys[addr] = pubKey
|
||||
}
|
||||
|
||||
func (m *mockDaemon) setJWTToken(token string) {
|
||||
m.impl.jwtToken = token
|
||||
}
|
||||
|
||||
func (m *mockDaemon) stop() {
|
||||
if m.server != nil {
|
||||
m.server.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
|
||||
t.Helper()
|
||||
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
|
||||
require.NoError(t, err)
|
||||
return pubKey
|
||||
}
|
||||
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
t.Helper()
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64.RawURLEncoding.EncodeToString(n),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
t.Helper()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
@@ -1,280 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||
const DefaultSSHPort = 44338
|
||||
|
||||
// TerminalTimeout is the timeout for terminal session to be ready
|
||||
const TerminalTimeout = 10 * time.Second
|
||||
|
||||
// TerminalBackoffDelay is the delay between terminal session readiness checks
|
||||
const TerminalBackoffDelay = 500 * time.Millisecond
|
||||
|
||||
// DefaultSSHServer is a function that creates DefaultServer
|
||||
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||
return newDefaultServer(hostKeyPEM, addr)
|
||||
}
|
||||
|
||||
// Server is an interface of SSH server
|
||||
type Server interface {
|
||||
// Stop stops SSH server.
|
||||
Stop() error
|
||||
// Start starts SSH server. Blocking
|
||||
Start() error
|
||||
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
|
||||
RemoveAuthorizedKey(peer string)
|
||||
// AddAuthorizedKey add a given peer key to server authorized keys
|
||||
AddAuthorizedKey(peer, newKey string) error
|
||||
}
|
||||
|
||||
// DefaultServer is the embedded NetBird SSH server
|
||||
type DefaultServer struct {
|
||||
listener net.Listener
|
||||
// authorizedKeys is ssh pub key indexed by peer WireGuard public key
|
||||
authorizedKeys map[string]ssh.PublicKey
|
||||
mu sync.Mutex
|
||||
hostKeyPEM []byte
|
||||
sessions []ssh.Session
|
||||
}
|
||||
|
||||
// newDefaultServer creates new server with provided host key
|
||||
func newDefaultServer(hostKeyPEM []byte, addr string) (*DefaultServer, error) {
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowedKeys := make(map[string]ssh.PublicKey)
|
||||
return &DefaultServer{listener: ln, mu: sync.Mutex{}, hostKeyPEM: hostKeyPEM, authorizedKeys: allowedKeys, sessions: make([]ssh.Session, 0)}, nil
|
||||
}
|
||||
|
||||
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
|
||||
func (srv *DefaultServer) RemoveAuthorizedKey(peer string) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
delete(srv.authorizedKeys, peer)
|
||||
}
|
||||
|
||||
// AddAuthorizedKey add a given peer key to server authorized keys
|
||||
func (srv *DefaultServer) AddAuthorizedKey(peer, newKey string) error {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
srv.authorizedKeys[peer] = parsedKey
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops SSH server.
|
||||
func (srv *DefaultServer) Stop() error {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
err := srv.listener.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, session := range srv.sessions {
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed closing SSH session from %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
for _, allowed := range srv.authorizedKeys {
|
||||
if ssh.KeysEqual(allowed, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
return []string{
|
||||
fmt.Sprint("SHELL=" + shell),
|
||||
fmt.Sprint("USER=" + user.Username),
|
||||
fmt.Sprint("HOME=" + user.HomeDir),
|
||||
}
|
||||
}
|
||||
|
||||
func acceptEnv(s string) bool {
|
||||
split := strings.Split(s, "=")
|
||||
if len(split) != 2 {
|
||||
return false
|
||||
}
|
||||
return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_")
|
||||
}
|
||||
|
||||
// sessionHandler handles SSH session post auth
|
||||
func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
||||
srv.mu.Lock()
|
||||
srv.sessions = append(srv.sessions, session)
|
||||
srv.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
|
||||
|
||||
localUser, err := userNameLookup(session.User())
|
||||
if err != nil {
|
||||
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
||||
err = session.Exit(1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User())
|
||||
return
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
if isPty {
|
||||
loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr())
|
||||
if err != nil {
|
||||
log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String())
|
||||
return
|
||||
}
|
||||
cmd := exec.Command(loginCmd, loginArgs...)
|
||||
go func() {
|
||||
<-session.Context().Done()
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
err := cmd.Process.Kill()
|
||||
if err != nil {
|
||||
log.Debugf("failed killing SSH process %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
cmd.Dir = localUser.HomeDir
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
cmd.Env = append(cmd.Env, v)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Login command: %s", cmd.String())
|
||||
file, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
log.Errorf("failed starting SSH server: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
setWinSize(file, win.Width, win.Height)
|
||||
}
|
||||
}()
|
||||
|
||||
srv.stdInOut(file, session)
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
_, err := io.WriteString(session, "only PTY is supported.\n")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = session.Exit(1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Debugf("SSH session ended")
|
||||
}
|
||||
|
||||
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
||||
go func() {
|
||||
// stdin
|
||||
_, err := io.Copy(file, session)
|
||||
if err != nil {
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// AWS Linux 2 machines need some time to open the terminal so we need to wait for it
|
||||
timer := time.NewTimer(TerminalTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
default:
|
||||
// stdout
|
||||
writtenBytes, err := io.Copy(session, file)
|
||||
if err != nil && writtenBytes != 0 {
|
||||
_ = session.Exit(0)
|
||||
return
|
||||
}
|
||||
time.Sleep(TerminalBackoffDelay)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts SSH server. Blocking
|
||||
func (srv *DefaultServer) Start() error {
|
||||
log.Infof("starting SSH server on addr: %s", srv.listener.Addr().String())
|
||||
|
||||
publicKeyOption := ssh.PublicKeyAuth(srv.publicKeyHandler)
|
||||
hostKeyPEM := ssh.HostKeyPEM(srv.hostKeyPEM)
|
||||
err := ssh.Serve(srv.listener, srv.sessionHandler, publicKeyOption, hostKeyPEM)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUserShell(userID string) string {
|
||||
if runtime.GOOS == "linux" {
|
||||
output, _ := exec.Command("getent", "passwd", userID).Output()
|
||||
line := strings.SplitN(string(output), ":", 10)
|
||||
if len(line) > 6 {
|
||||
return strings.TrimSpace(line[6])
|
||||
}
|
||||
}
|
||||
|
||||
shell := os.Getenv("SHELL")
|
||||
if shell == "" {
|
||||
shell = "/bin/sh"
|
||||
}
|
||||
return shell
|
||||
}
|
||||
206
client/ssh/server/command_execution.go
Normal file
206
client/ssh/server/command_execution.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handleCommand executes an SSH command with privilege validation
|
||||
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
|
||||
hasPty := winCh != nil
|
||||
|
||||
commandType := "command"
|
||||
if hasPty {
|
||||
commandType = "Pty command"
|
||||
}
|
||||
|
||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||
|
||||
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
|
||||
if err != nil {
|
||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||
|
||||
errorMsg := fmt.Sprintf("Cannot create %s - platform may not support user switching", commandType)
|
||||
if hasPty {
|
||||
errorMsg += " with Pty"
|
||||
}
|
||||
errorMsg += "\n"
|
||||
|
||||
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !hasPty {
|
||||
if s.executeCommand(logger, session, execCmd, cleanup) {
|
||||
logger.Debugf("%s execution completed", commandType)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
defer cleanup()
|
||||
|
||||
ptyReq, _, _ := session.Pty()
|
||||
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
||||
logger.Debugf("%s execution completed", commandType)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
return nil, nil, errors.New("no user in privilege result")
|
||||
}
|
||||
|
||||
// If PTY requested but su doesn't support --pty, skip su and use executor
|
||||
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
||||
if hasPty && !s.suSupportsPty {
|
||||
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
// Try su first for system integration (PAM/audit) when privileged
|
||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
||||
if err != nil || privilegeResult.UsedFallback {
|
||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, func() {}, nil
|
||||
}
|
||||
|
||||
// executeCommand executes the command and handles I/O and exit codes
|
||||
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool {
|
||||
defer cleanup()
|
||||
|
||||
s.setupProcessGroup(execCmd)
|
||||
|
||||
stdinPipe, err := execCmd.StdinPipe()
|
||||
if err != nil {
|
||||
logger.Errorf("create stdin pipe: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
execCmd.Stdout = session
|
||||
execCmd.Stderr = session.Stderr()
|
||||
|
||||
if execCmd.Dir != "" {
|
||||
if _, err := os.Stat(execCmd.Dir); err != nil {
|
||||
logger.Warnf("working directory does not exist: %s (%v)", execCmd.Dir, err)
|
||||
execCmd.Dir = "/"
|
||||
}
|
||||
}
|
||||
|
||||
if err := execCmd.Start(); err != nil {
|
||||
logger.Errorf("command start failed: %v", err)
|
||||
// no user message for exec failure, just exit
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
go s.handleCommandIO(logger, stdinPipe, session)
|
||||
return s.waitForCommandCleanup(logger, session, execCmd)
|
||||
}
|
||||
|
||||
// handleCommandIO manages stdin/stdout copying in a goroutine
|
||||
func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, session ssh.Session) {
|
||||
defer func() {
|
||||
if err := stdinPipe.Close(); err != nil {
|
||||
logger.Debugf("stdin pipe close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(stdinPipe, session); err != nil {
|
||||
logger.Debugf("stdin copy error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCommandCleanup waits for command completion with session disconnect handling
|
||||
func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
|
||||
ctx := session.Context()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- execCmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session cancelled, terminating command")
|
||||
s.killProcessGroup(execCmd)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
logger.Tracef("command terminated after session cancellation: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
logger.Warnf("command did not terminate within 5 seconds after session cancellation")
|
||||
}
|
||||
|
||||
if err := session.Exit(130); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
|
||||
case err := <-done:
|
||||
return s.handleCommandCompletion(logger, session, err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCommandCompletion handles command completion
|
||||
func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool {
|
||||
if err != nil {
|
||||
logger.Debugf("command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
return false
|
||||
}
|
||||
|
||||
s.handleSessionExit(session, nil, logger)
|
||||
return true
|
||||
}
|
||||
|
||||
// handleSessionExit handles command errors and sets appropriate exit codes
|
||||
func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) {
|
||||
if err == nil {
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
if err := session.Exit(exitError.ExitCode()); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
} else {
|
||||
logger.Debugf("non-exit error in command execution: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
52
client/ssh/server/command_execution_js.go
Normal file
52
client/ssh/server/command_execution_js.go
Normal file
@@ -0,0 +1,52 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||
|
||||
// createSuCommand is not supported on JS/WASM
|
||||
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
return nil, errNotSupported
|
||||
}
|
||||
|
||||
// createExecutorCommand is not supported on JS/WASM
|
||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||
return nil, nil, errNotSupported
|
||||
}
|
||||
|
||||
// prepareCommandEnv is not supported on JS/WASM
|
||||
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupProcessGroup is not supported on JS/WASM
|
||||
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
|
||||
}
|
||||
|
||||
// killProcessGroup is not supported on JS/WASM
|
||||
func (s *Server) killProcessGroup(*exec.Cmd) {
|
||||
}
|
||||
|
||||
// detectSuPtySupport always returns false on JS/WASM
|
||||
func (s *Server) detectSuPtySupport(context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// executeCommandWithPty is not supported on JS/WASM
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
logger.Errorf("PTY command execution not supported on JS/WASM")
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
329
client/ssh/server/command_execution_unix.go
Normal file
329
client/ssh/server/command_execution_unix.go
Normal file
@@ -0,0 +1,329 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ptyManager manages Pty file operations with thread safety
|
||||
type ptyManager struct {
|
||||
file *os.File
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
closeErr error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newPtyManager(file *os.File) *ptyManager {
|
||||
return &ptyManager{file: file}
|
||||
}
|
||||
|
||||
func (pm *ptyManager) Close() error {
|
||||
pm.once.Do(func() {
|
||||
pm.mu.Lock()
|
||||
pm.closed = true
|
||||
pm.closeErr = pm.file.Close()
|
||||
pm.mu.Unlock()
|
||||
})
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
return pm.closeErr
|
||||
}
|
||||
|
||||
func (pm *ptyManager) Setsize(ws *pty.Winsize) error {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
if pm.closed {
|
||||
return errors.New("pty is closed")
|
||||
}
|
||||
return pty.Setsize(pm.file, ws)
|
||||
}
|
||||
|
||||
func (pm *ptyManager) File() *os.File {
|
||||
return pm.file
|
||||
}
|
||||
|
||||
// detectSuPtySupport checks if su supports the --pty flag
|
||||
func (s *Server) detectSuPtySupport(ctx context.Context) bool {
|
||||
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "su", "--help")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Debugf("su --help failed (may not support --help): %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
supported := strings.Contains(string(output), "--pty")
|
||||
log.Debugf("su --pty support detected: %v", supported)
|
||||
return supported
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching
|
||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
suPath, err := exec.LookPath("su")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("su command not available: %w", err)
|
||||
}
|
||||
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("no command specified for su execution")
|
||||
}
|
||||
|
||||
args := []string{"-l"}
|
||||
if hasPty && s.suSupportsPty {
|
||||
args = append(args, "--pty")
|
||||
}
|
||||
args = append(args, localUser.Username, "-c", command)
|
||||
|
||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell, "-l"}
|
||||
}
|
||||
return []string{shell, "-l", "-c", cmdString}
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// executeCommandWithPty executes a command with PTY allocation
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
termType := ptyReq.Term
|
||||
if termType == "" {
|
||||
termType = "xterm-256color"
|
||||
}
|
||||
execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType))
|
||||
|
||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty command creation failed: %v", err)
|
||||
errorMsg := "User switching failed - login command not available\r\n"
|
||||
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Infof("starting interactive shell: %s", execCmd.Path)
|
||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||
}
|
||||
|
||||
// runPtyCommand runs a command with PTY management (common code for interactive and command execution)
|
||||
func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty start failed: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
ptyMgr := newPtyManager(ptmx)
|
||||
defer func() {
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go s.handlePtyWindowResize(logger, session, ptyMgr, winCh)
|
||||
s.handlePtyIO(logger, session, ptyMgr)
|
||||
s.waitForPtyCompletion(logger, session, execCmd, ptyMgr)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) {
|
||||
winSize := &pty.Winsize{
|
||||
Cols: uint16(ptyReq.Window.Width),
|
||||
Rows: uint16(ptyReq.Window.Height),
|
||||
}
|
||||
if winSize.Cols == 0 {
|
||||
winSize.Cols = 80
|
||||
}
|
||||
if winSize.Rows == 0 {
|
||||
winSize.Rows = 24
|
||||
}
|
||||
|
||||
ptmx, err := pty.StartWithSize(execCmd, winSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start Pty: %w", err)
|
||||
}
|
||||
|
||||
return ptmx, nil
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) {
|
||||
for {
|
||||
select {
|
||||
case <-session.Context().Done():
|
||||
return
|
||||
case win, ok := <-winCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil {
|
||||
logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) {
|
||||
ptmx := ptyMgr.File()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(ptmx, session); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty input copy error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(session, ptmx); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty output copy error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) {
|
||||
ctx := session.Context()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- execCmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
||||
case err := <-done:
|
||||
s.handlePtyCommandCompletion(logger, session, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) {
|
||||
logger.Debugf("Pty session cancelled, terminating command")
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close during session cancellation: %v", err)
|
||||
}
|
||||
|
||||
s.killProcessGroup(execCmd)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command terminated after session cancellation with error: %v", err)
|
||||
} else {
|
||||
logger.Debugf("Pty command terminated after session cancellation")
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation")
|
||||
}
|
||||
|
||||
if err := session.Exit(130); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Normal completion
|
||||
logger.Debugf("Pty command completed successfully")
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupProcessGroup(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
pgid := cmd.Process.Pid
|
||||
|
||||
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
|
||||
logger.Debugf("kill process group SIGTERM: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
const gracePeriod = 500 * time.Millisecond
|
||||
const checkInterval = 50 * time.Millisecond
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.After(gracePeriod)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
|
||||
logger.Debugf("kill process group SIGKILL: %v", err)
|
||||
}
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := syscall.Kill(-pgid, 0); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user