diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index d4db84aa3..6ca941626 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -184,21 +184,22 @@ func parseCustomSSHFlags(args []string) ([]string, []string, []string) { for i := 0; i < len(args); i++ { arg := args[i] - if strings.HasPrefix(arg, "-L") { + switch { + case strings.HasPrefix(arg, "-L"): if arg == "-L" && i+1 < len(args) { localForwardFlags = append(localForwardFlags, args[i+1]) i++ } else if len(arg) > 2 { localForwardFlags = append(localForwardFlags, arg[2:]) } - } else if strings.HasPrefix(arg, "-R") { + case strings.HasPrefix(arg, "-R"): if arg == "-R" && i+1 < len(args) { remoteForwardFlags = append(remoteForwardFlags, args[i+1]) i++ } else if len(arg) > 2 { remoteForwardFlags = append(remoteForwardFlags, arg[2:]) } - } else { + default: filteredArgs = append(filteredArgs, arg) } } diff --git a/client/cmd/ssh_sftp_unix.go b/client/cmd/ssh_sftp_unix.go index 470af9491..7723165cf 100644 --- a/client/cmd/ssh_sftp_unix.go +++ b/client/cmd/ssh_sftp_unix.go @@ -86,9 +86,15 @@ func sftpMain(cmd *cobra.Command, _ []string) error { log.Tracef("starting SFTP server with dropped privileges") if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) { cmd.PrintErrf("SFTP server error: %v\n", err) + if closeErr := sftpServer.Close(); closeErr != nil { + cmd.PrintErrf("SFTP server close error: %v\n", closeErr) + } os.Exit(sshserver.ExitCodeShellExecFail) } + if closeErr := sftpServer.Close(); closeErr != nil { + cmd.PrintErrf("SFTP server close error: %v\n", closeErr) + } os.Exit(sshserver.ExitCodeSuccess) return nil } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index f9e213597..d2f7a63be 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -1267,15 +1267,6 @@ func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint m.logger.Debug("Unregistered netstack service on protocol %s port %d", protocol, port) } -// isNetstackService checks if a service is registered as listening on netstack for the given protocol and port -func (m *Manager) isNetstackService(layerType gopacket.LayerType, port uint16) bool { - m.netstackServiceMutex.RLock() - defer m.netstackServiceMutex.RUnlock() - key := serviceKey{protocol: layerType, port: port} - _, exists := m.netstackServices[key] - return exists -} - // protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType { switch protocol { diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 50cac01d9..9a7fa4d3d 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -16,8 +16,11 @@ import ( var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") +var ( + errInvalidIPHeaderLength = errors.New("invalid IP header length") +) + const ( - invalidIPHeaderLengthMsg = "invalid IP header length" errRewriteTCPDestinationPort = "rewrite TCP destination port: %v" ) @@ -175,21 +178,6 @@ func (t *portNATTracker) getConnectionNAT(srcIP, dstIP netip.Addr, srcPort, dstP return conn, exists } -// removeConnection removes a tracked connection from the NAT tracking table. -func (t *portNATTracker) removeConnection(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { - t.mutex.Lock() - defer t.mutex.Unlock() - - key := ConnKey{ - SrcIP: srcIP, - DstIP: dstIP, - SrcPort: srcPort, - DstPort: dstPort, - } - - delete(t.connections, key) -} - // shouldApplyNAT checks if NAT should be applied to a new connection to prevent bidirectional conflicts. func (t *portNATTracker) shouldApplyNAT(srcIP, dstIP netip.Addr, dstPort uint16) bool { t.mutex.RLock() @@ -390,7 +378,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf(invalidIPHeaderLengthMsg) + return errInvalidIPHeaderLength } binary.BigEndian.PutUint16(packetData[10:12], 0) @@ -425,7 +413,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf(invalidIPHeaderLengthMsg) + return errInvalidIPHeaderLength } binary.BigEndian.PutUint16(packetData[10:12], 0) @@ -560,11 +548,12 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services on specific ports. func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { var layerType gopacket.LayerType - if protocol == firewall.ProtocolTCP { + switch protocol { + case firewall.ProtocolTCP: layerType = layers.LayerTypeTCP - } else if protocol == firewall.ProtocolUDP { + case firewall.ProtocolUDP: layerType = layers.LayerTypeUDP - } else { + default: return fmt.Errorf("unsupported protocol: %s", protocol) } @@ -594,11 +583,12 @@ func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.L // RemoveInboundDNAT removes inbound DNAT rule for specified local address and ports. func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { var layerType gopacket.LayerType - if protocol == firewall.ProtocolTCP { + switch protocol { + case firewall.ProtocolTCP: layerType = layers.LayerTypeTCP - } else if protocol == firewall.ProtocolUDP { + case firewall.ProtocolUDP: layerType = layers.LayerTypeUDP - } else { + default: return fmt.Errorf("unsupported protocol: %s", protocol) } @@ -747,7 +737,7 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf(invalidIPHeaderLengthMsg) + return errInvalidIPHeaderLength } tcpStart := ipHeaderLen @@ -786,7 +776,7 @@ func (m *Manager) rewriteTCPSourcePort(packetData []byte, d *decoder, newPort ui ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf(invalidIPHeaderLengthMsg) + return errInvalidIPHeaderLength } tcpStart := ipHeaderLen diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index f3cd1a5d0..4c43077bc 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -538,7 +538,9 @@ func TestSSHPortRedirectionEndToEnd(t *testing.T) { // Read server response buf := make([]byte, 1024) - conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Logf("failed to set read deadline: %v", err) + } n, err := conn.Read(buf) require.NoError(t, err, "Should read server response") diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 30957baec..1dc5c72e1 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -273,7 +273,7 @@ func DialInsecure(ctx context.Context, addr, user string) (*Client, error) { config := &ssh.ClientConfig{ User: user, Timeout: 30 * time.Second, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: ssh.InsecureIgnoreHostKey(), // #nosec G106 - Only used for tests } return dial(ctx, "tcp", addr, config) diff --git a/client/ssh/client/terminal_unix.go b/client/ssh/client/terminal_unix.go index cc8846d58..b47262143 100644 --- a/client/ssh/client/terminal_unix.go +++ b/client/ssh/client/terminal_unix.go @@ -95,31 +95,31 @@ func (c *Client) setupTerminal(session *ssh.Session, fd int) error { ssh.TTY_OP_ISPEED: 14400, ssh.TTY_OP_OSPEED: 14400, // Ctrl+C - ssh.VINTR: 3, + ssh.VINTR: 3, // Ctrl+\ - ssh.VQUIT: 28, + ssh.VQUIT: 28, // Backspace - ssh.VERASE: 127, + ssh.VERASE: 127, // Ctrl+U - ssh.VKILL: 21, + ssh.VKILL: 21, // Ctrl+D - ssh.VEOF: 4, - ssh.VEOL: 0, - ssh.VEOL2: 0, + ssh.VEOF: 4, + ssh.VEOL: 0, + ssh.VEOL2: 0, // Ctrl+Q - ssh.VSTART: 17, + ssh.VSTART: 17, // Ctrl+S - ssh.VSTOP: 19, + ssh.VSTOP: 19, // Ctrl+Z - ssh.VSUSP: 26, + ssh.VSUSP: 26, // Ctrl+O - ssh.VDISCARD: 15, + ssh.VDISCARD: 15, // Ctrl+R - ssh.VREPRINT: 18, + ssh.VREPRINT: 18, // Ctrl+W - ssh.VWERASE: 23, + ssh.VWERASE: 23, // Ctrl+V - ssh.VLNEXT: 22, + ssh.VLNEXT: 22, } terminal := os.Getenv("TERM") diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index 0e61b4e65..ab59c3d15 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -233,7 +233,6 @@ func (m *Manager) SetupSSHClientConfigWithPeers(domains []string, peerKeys []Pee } } - // Try to create system-wide SSH config if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil { log.Warnf("Failed to create SSH config directory %s: %v", m.sshConfigDir, err) @@ -334,7 +333,8 @@ func (m *Manager) RemoveSSHClientConfig() error { // Also try to clean up user config homeDir, err := os.UserHomeDir() if err != nil { - return nil // Not critical + log.Debugf("failed to get user home directory: %v", err) + return nil } userConfigPath := filepath.Join(homeDir, ".ssh", "config") @@ -541,7 +541,8 @@ func (m *Manager) RemoveKnownHostsFile() error { // Also try to clean up user known_hosts homeDir, err := os.UserHomeDir() if err != nil { - return nil // Not critical + log.Debugf("failed to get user home directory: %v", err) + return nil } userKnownHostsPath := filepath.Join(homeDir, ".ssh", m.userKnownHosts) @@ -553,4 +554,3 @@ func (m *Manager) RemoveKnownHostsFile() error { return nil } - diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go index 3b356189a..f8b0373dc 100644 --- a/client/ssh/config/manager_test.go +++ b/client/ssh/config/manager_test.go @@ -207,9 +207,7 @@ func TestManager_DirectoryFallback(t *testing.T) { defer os.RemoveAll(tempDir) // Set HOME to temp directory to control user fallback - originalHome := os.Getenv("HOME") - os.Setenv("HOME", tempDir) - defer os.Setenv("HOME", originalHome) + t.Setenv("HOME", tempDir) // Create manager with non-writable system directories manager := &Manager{ @@ -306,15 +304,7 @@ func TestManager_PeerLimit(t *testing.T) { func TestManager_ForcedSSHConfig(t *testing.T) { // Set force environment variable - originalForce := os.Getenv(EnvForceSSHConfig) - os.Setenv(EnvForceSSHConfig, "true") - defer func() { - if originalForce == "" { - os.Unsetenv(EnvForceSSHConfig) - } else { - os.Setenv(EnvForceSSHConfig, originalForce) - } - }() + t.Setenv(EnvForceSSHConfig, "true") // Create temporary directory for test tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") diff --git a/client/ssh/server/command_execution.go b/client/ssh/server/command_execution.go index bf7e36dd4..29afa518f 100644 --- a/client/ssh/server/command_execution.go +++ b/client/ssh/server/command_execution.go @@ -36,7 +36,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege } errorMsg += "\n" - if _, writeErr := fmt.Fprintf(session.Stderr(), errorMsg); writeErr != nil { + if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil { logger.Debugf(errWriteSession, writeErr) } if err := session.Exit(1); err != nil { @@ -150,34 +150,6 @@ func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, se } } -// waitForCommandCompletion waits for command completion and handles exit codes -func (s *Server) waitForCommandCompletion(sessionKey SessionKey, session ssh.Session, execCmd *exec.Cmd) bool { - logger := log.WithField("session", sessionKey) - - if err := execCmd.Wait(); err != nil { - logger.Debugf("command execution failed: %v", err) - var exitError *exec.ExitError - if errors.As(err, &exitError) { - if err := session.Exit(exitError.ExitCode()); err != nil { - logger.Debugf(errExitSession, err) - } - } else { - if _, writeErr := fmt.Fprintf(session.Stderr(), "failed to execute command: %v\n", err); writeErr != nil { - logger.Debugf(errWriteSession, writeErr) - } - if err := session.Exit(1); err != nil { - logger.Debugf(errExitSession, err) - } - } - return false - } - - if err := session.Exit(0); err != nil { - logger.Debugf(errExitSession, err) - } - return true -} - // createPtyCommandWithPrivileges creates the exec.Cmd for Pty execution respecting privilege check results func (s *Server) createPtyCommandWithPrivileges(cmd []string, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) { localUser := privilegeResult.User @@ -246,23 +218,6 @@ func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, e } } -// handleCommandSessionCancellation handles command session cancellation -func (s *Server) handleCommandSessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, done <-chan error) { - logger.Debugf("session cancelled, terminating command") - s.killProcessGroup(execCmd) - - select { - case err := <-done: - logger.Debugf("command terminated after session cancellation: %v", err) - case <-time.After(5 * time.Second): - logger.Warnf("command did not terminate within 5 seconds after session cancellation") - } - - if err := session.Exit(130); err != nil { - logger.Debugf(errExitSession, err) - } -} - // handleCommandCompletion handles command completion func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool { if err != nil { diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 772b4d4a6..552545adc 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -15,6 +15,7 @@ import ( "testing" "time" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" @@ -88,28 +89,7 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username "echo", "hello_world") output, err := cmd.CombinedOutput() - - if err != nil { - t.Logf("SSH command failed: %v", err) - t.Logf("Output: %s", string(output)) - return - } - assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully") -} - -// testSSHCommandExecution tests basic command execution with system SSH client. -func testSSHCommandExecution(t *testing.T, host, port, keyFile string) { - cmd := exec.Command("ssh", - "-i", keyFile, - "-p", port, - "-o", "StrictHostKeyChecking=no", - "-o", "UserKnownHostsFile=/dev/null", - "-o", "ConnectTimeout=5", - fmt.Sprintf("test-user@%s", host), - "echo", "hello_world") - - output, err := cmd.CombinedOutput() if err != nil { t.Logf("SSH command failed: %v", err) t.Logf("Output: %s", string(output)) @@ -269,7 +249,9 @@ func testSSHPortForwarding(t *testing.T, host, port, keyFile string) { _, err = conn.Write([]byte(request)) require.NoError(t, err) - conn.SetReadDeadline(time.Now().Add(3 * time.Second)) + if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { + log.Debugf("failed to set read deadline: %v", err) + } response := make([]byte, len(expectedResponse)) n, err := io.ReadFull(conn, response) if err != nil { @@ -305,16 +287,16 @@ func generateOpenSSHKey() ([]byte, []byte, error) { // Remove the temp file so ssh-keygen can create it if err := os.Remove(keyPath); err != nil { - // Ignore if file doesn't exist, we just need it gone + t.Logf("failed to remove key file: %v", err) } // Clean up temp files defer func() { if err := os.Remove(keyPath); err != nil { - // Ignore cleanup errors but could log them in debug mode + t.Logf("failed to cleanup key file: %v", err) } if err := os.Remove(keyPath + ".pub"); err != nil { - // Ignore cleanup errors but could log them in debug mode + t.Logf("failed to cleanup public key file: %v", err) } }() diff --git a/client/ssh/server/executor_test.go b/client/ssh/server/executor_test.go deleted file mode 100644 index c7791c185..000000000 --- a/client/ssh/server/executor_test.go +++ /dev/null @@ -1,226 +0,0 @@ -//go:build unix - -package server - -import ( - "context" - "os" - "os/exec" - "os/user" - "runtime" - "strconv" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) { - pd := NewPrivilegeDropper() - - tests := []struct { - name string - uid uint32 - gid uint32 - wantErr bool - }{ - { - name: "valid non-root user", - uid: 1000, - gid: 1000, - wantErr: false, - }, - { - name: "root UID should be rejected", - uid: 0, - gid: 1000, - wantErr: true, - }, - { - name: "root GID should be rejected", - uid: 1000, - gid: 0, - wantErr: true, - }, - { - name: "both root should be rejected", - uid: 0, - gid: 0, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := pd.validatePrivileges(tt.uid, tt.gid) - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) { - pd := NewPrivilegeDropper() - - config := ExecutorConfig{ - UID: 1000, - GID: 1000, - Groups: []uint32{1000, 1001}, - WorkingDir: "/home/testuser", - Shell: "/bin/bash", - Command: "ls -la", - } - - cmd, err := pd.CreateExecutorCommand(context.Background(), config) - require.NoError(t, err) - require.NotNil(t, cmd) - - // Verify the command is calling netbird ssh exec - assert.Contains(t, cmd.Args, "ssh") - assert.Contains(t, cmd.Args, "exec") - assert.Contains(t, cmd.Args, "--uid") - assert.Contains(t, cmd.Args, "1000") - assert.Contains(t, cmd.Args, "--gid") - assert.Contains(t, cmd.Args, "1000") - assert.Contains(t, cmd.Args, "--groups") - assert.Contains(t, cmd.Args, "1000") - assert.Contains(t, cmd.Args, "1001") - assert.Contains(t, cmd.Args, "--working-dir") - assert.Contains(t, cmd.Args, "/home/testuser") - assert.Contains(t, cmd.Args, "--shell") - assert.Contains(t, cmd.Args, "/bin/bash") - assert.Contains(t, cmd.Args, "--cmd") - assert.Contains(t, cmd.Args, "ls -la") -} - -func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) { - pd := NewPrivilegeDropper() - - config := ExecutorConfig{ - UID: 1000, - GID: 1000, - Groups: []uint32{1000}, - WorkingDir: "/home/testuser", - Shell: "/bin/bash", - Command: "", - } - - cmd, err := pd.CreateExecutorCommand(context.Background(), config) - require.NoError(t, err) - require.NotNil(t, cmd) - - // Verify no command mode (command is empty so no --cmd flag) - assert.NotContains(t, cmd.Args, "--cmd") - assert.NotContains(t, cmd.Args, "--interactive") -} - -// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping -// This test requires root privileges and will be skipped if not running as root -func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Privilege dropping not supported on Windows") - } - - if os.Geteuid() != 0 { - t.Skip("This test requires root privileges") - } - - // Find a non-root user to test with - testUser, err := user.Lookup("nobody") - if err != nil { - // Try to find any non-root user - testUser, err = findNonRootUser() - if err != nil { - t.Skip("No suitable non-root user found for testing") - } - } - - uid64, err := strconv.ParseUint(testUser.Uid, 10, 32) - require.NoError(t, err) - targetUID := uint32(uid64) - - gid64, err := strconv.ParseUint(testUser.Gid, 10, 32) - require.NoError(t, err) - targetGID := uint32(gid64) - - // Test in a child process to avoid affecting the test runner - if os.Getenv("TEST_PRIVILEGE_DROP") == "1" { - pd := NewPrivilegeDropper() - - // This should succeed - err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID}) - require.NoError(t, err) - - // Verify we are now running as the target user - currentUID := uint32(os.Geteuid()) - currentGID := uint32(os.Getegid()) - - assert.Equal(t, targetUID, currentUID, "UID should match target") - assert.Equal(t, targetGID, currentGID, "GID should match target") - assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root") - assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group") - - return - } - - // Fork a child process to test privilege dropping - cmd := os.Args[0] - args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"} - - env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1") - - execCmd := exec.Command(cmd, args...) - execCmd.Env = env - - err = execCmd.Run() - require.NoError(t, err, "Child process should succeed") -} - -// findNonRootUser finds any non-root user on the system for testing -func findNonRootUser() (*user.User, error) { - // Try common non-root users - commonUsers := []string{"nobody", "daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"} - - for _, username := range commonUsers { - if u, err := user.Lookup(username); err == nil { - uid64, err := strconv.ParseUint(u.Uid, 10, 32) - if err != nil { - continue - } - if uid64 != 0 { // Not root - return u, nil - } - } - } - - // If no common users found, create a minimal user info for testing - // This won't actually work for privilege dropping but allows the test structure - return &user.User{ - Uid: "65534", // Standard nobody UID - Gid: "65534", // Standard nobody GID - Username: "nobody", - Name: "nobody", - HomeDir: "/nonexistent", - }, nil -} - -func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) { - pd := NewPrivilegeDropper() - - // Test validation of root privileges - this should be caught in CreateExecutorCommand - config := ExecutorConfig{ - UID: 0, // Root UID should be rejected - GID: 1000, - Groups: []uint32{1000}, - WorkingDir: "/tmp", - Shell: "/bin/sh", - Command: "echo test", - } - - _, err := pd.CreateExecutorCommand(context.Background(), config) - assert.Error(t, err) - assert.Contains(t, err.Error(), "root user") -} diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index 37e232f17..4cdac9c4a 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -376,23 +376,6 @@ func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn } } -// registerConnectionCancel stores a cancel function for a connection -func (s *Server) registerConnectionCancel(key ConnectionKey, cancel context.CancelFunc) { - s.mu.Lock() - defer s.mu.Unlock() - if s.sessionCancels == nil { - s.sessionCancels = make(map[ConnectionKey]context.CancelFunc) - } - s.sessionCancels[key] = cancel -} - -// unregisterConnectionCancel removes a connection's cancel function -func (s *Server) unregisterConnectionCancel(key ConnectionKey) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.sessionCancels, key) -} - // monitorSessionContext watches for session cancellation and closes connections func (s *Server) monitorSessionContext(ctx context.Context, channel cryptossh.Channel, conn net.Conn, closed chan struct{}, closeOnce *bool, logger *log.Entry) { <-ctx.Done() diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index d0ba2e30e..d76a70def 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -375,16 +375,6 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey { return "unknown" } -// cleanupConnectionPortForward removes port forward state from a connection -func (s *Server) cleanupConnectionPortForward(sshConn *cryptossh.ServerConn) { - s.mu.Lock() - defer s.mu.Unlock() - - if state, exists := s.sshConnections[sshConn]; exists { - state.hasActivePortForward = false - } -} - // connectionValidator validates incoming connections based on source IP func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { s.mu.RLock() diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index 76174fe07..f1132e7ad 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -129,17 +129,17 @@ func (s *Server) buildUserLookupErrorMessage(err error) string { switch { case errors.As(err, &privilegedErr): if privilegedErr.Username == "root" { - return fmt.Sprintf("root login is disabled on this SSH server\n") + return "root login is disabled on this SSH server\n" } - return fmt.Sprintf("privileged user access is disabled on this SSH server\n") + return "privileged user access is disabled on this SSH server\n" case errors.Is(err, ErrPrivilegeRequired): - return fmt.Sprintf("Windows user switching failed - NetBird must run with elevated privileges for user switching\n") + return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n" case errors.Is(err, ErrPrivilegedUserSwitch): - return fmt.Sprintf("Cannot switch to privileged user - current user lacks required privileges\n") + return "Cannot switch to privileged user - current user lacks required privileges\n" default: - return fmt.Sprintf("User authentication failed\n") + return "User authentication failed\n" } } diff --git a/client/ssh/server/shell.go b/client/ssh/server/shell.go index 7de658909..beff8fce7 100644 --- a/client/ssh/server/shell.go +++ b/client/ssh/server/shell.go @@ -18,7 +18,7 @@ import ( const ( defaultUnixShell = "/bin/sh" - pwshExe = "pwsh.exe" + pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name powershellExe = "powershell.exe" ) @@ -104,7 +104,7 @@ func prepareUserEnv(user *user.User, shell string) []string { fmt.Sprint("USER=" + user.Username), fmt.Sprint("LOGNAME=" + user.Username), fmt.Sprint("HOME=" + user.HomeDir), - fmt.Sprint("PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"), + "PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games", } } diff --git a/client/ssh/server/socket_filter_linux.go b/client/ssh/server/socket_filter_linux.go index 8b17b99e9..730317192 100644 --- a/client/ssh/server/socket_filter_linux.go +++ b/client/ssh/server/socket_filter_linux.go @@ -85,7 +85,7 @@ func attachSocketFilter(listener net.Listener, wgIfIndex int) error { fd := int(file.Fd()) _, _, errno := syscall.Syscall6( - syscall.SYS_SETSOCKOPT, + unix.SYS_SETSOCKOPT, uintptr(fd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_ATTACH_FILTER), diff --git a/client/ssh/server/user_utils.go b/client/ssh/server/user_utils.go index 24bfd9335..b82aa6b8a 100644 --- a/client/ssh/server/user_utils.go +++ b/client/ssh/server/user_utils.go @@ -183,17 +183,6 @@ func isSameResolvedUser(user1, user2 *user.User) bool { return user1.Uid == user2.Uid } -// logPrivilegeCheckResult logs the final result of privilege checking -func (s *Server) logPrivilegeCheckResult(req PrivilegeCheckRequest, result PrivilegeCheckResult) { - if !result.Allowed { - log.Debugf("Privilege check denied for %s (user: %s, feature: %s): %v", - req.FeatureName, req.RequestedUsername, req.FeatureName, result.Error) - } else { - log.Debugf("Privilege check allowed for %s (user: %s, requires_switching: %v)", - req.FeatureName, req.RequestedUsername, result.RequiresUserSwitching) - } -} - // privilegeCheckContext holds all context needed for privilege checking type privilegeCheckContext struct { currentUser *user.User @@ -389,7 +378,7 @@ func isWindowsPrivilegedSID(sid string) bool { return false } -// buildShellArgs builds shell arguments for executing commands. +// buildShellArgs builds shell arguments for executing commands func buildShellArgs(shell, command string) []string { if command != "" { return []string{shell, "-Command", command} diff --git a/client/ssh/server/user_utils_test.go b/client/ssh/server/user_utils_test.go index 5d3bede15..d0369379c 100644 --- a/client/ssh/server/user_utils_test.go +++ b/client/ssh/server/user_utils_test.go @@ -674,19 +674,20 @@ func TestCheckPrivileges_ActualPlatform(t *testing.T) { FeatureName: "SSH login", }) - if actualOS == "windows" { + switch { + case actualOS == "windows": // Windows should deny user switching assert.False(t, result.Allowed, "Windows should deny user switching") assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed") assert.Contains(t, result.Error.Error(), "user switching not supported", "Should indicate user switching not supported") - } else if !actualIsPrivileged { + case !actualIsPrivileged: // Non-privileged Unix processes should fallback to current user assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user") assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens") assert.True(t, result.UsedFallback, "Should indicate fallback was used") assert.NotNil(t, result.User, "Should return current user") - } else { + default: // Privileged Unix processes should attempt user lookup assert.False(t, result.Allowed, "Should fail due to nonexistent user") assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")