diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 525bcdef1..0acf0b133 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -634,7 +634,11 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward return err } - cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr) + if err := validateDestinationPort(remoteAddr); err != nil { + return fmt.Errorf("invalid remote address: %w", err) + } + + log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr) go func() { if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) { @@ -652,7 +656,11 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar return err } - cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr) + if err := validateDestinationPort(localAddr); err != nil { + return fmt.Errorf("invalid local address: %w", err) + } + + log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr) go func() { if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) { @@ -663,6 +671,35 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar return nil } +// validateDestinationPort checks that the destination address has a valid port. +// Port 0 is only valid for bind addresses (where the OS picks an available port), +// not for destination addresses where we need to connect. +func validateDestinationPort(addr string) error { + if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") { + return nil + } + + _, portStr, err := net.SplitHostPort(addr) + if err != nil { + return fmt.Errorf("parse address %s: %w", addr, err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("invalid port %s: %w", portStr, err) + } + + if port == 0 { + return fmt.Errorf("port 0 is not valid for destination address") + } + + if port < 0 || port > 65535 { + return fmt.Errorf("port %d out of range (1-65535)", port) + } + + return nil +} + // parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80". // Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket". func parsePortForwardSpec(spec string) (string, string, error) { diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 5ae0c1ad1..5d56befc7 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -2013,6 +2013,7 @@ type SSHSessionInfo struct { RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"` Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"` JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"` + PortForwards []string `protobuf:"bytes,5,rep,name=portForwards,proto3" json:"portForwards,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -2075,6 +2076,13 @@ func (x *SSHSessionInfo) GetJwtUsername() string { return "" } +func (x *SSHSessionInfo) GetPortForwards() []string { + if x != nil { + return x.PortForwards + } + return nil +} + // SSHServerState contains the latest state of the SSH server type SSHServerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -5706,12 +5714,13 @@ const file_daemon_proto_rawDesc = "" + "\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" + "\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" + "\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" + - "\x05error\x18\x04 \x01(\tR\x05error\"\x8e\x01\n" + + "\x05error\x18\x04 \x01(\tR\x05error\"\xb2\x01\n" + "\x0eSSHSessionInfo\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12$\n" + "\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" + "\acommand\x18\x03 \x01(\tR\acommand\x12 \n" + - "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" + + "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\x12\"\n" + + "\fportForwards\x18\x05 \x03(\tR\fportForwards\"^\n" + "\x0eSSHServerState\x12\x18\n" + "\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" + "\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 5f30bfe4b..b75ca821a 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -372,6 +372,7 @@ message SSHSessionInfo { string remoteAddress = 2; string command = 3; string jwtUsername = 4; + repeated string portForwards = 5; } // SSHServerState contains the latest state of the SSH server diff --git a/client/server/server.go b/client/server/server.go index 35ac04381..7b6c4e98c 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -1104,6 +1104,7 @@ func (s *Server) getSSHServerState() *proto.SSHServerState { RemoteAddress: session.RemoteAddress, Command: session.Command, JwtUsername: session.JWTUsername, + PortForwards: session.PortForwards, }) } diff --git a/client/ssh/auth/auth.go b/client/ssh/auth/auth.go index 488b6e12e..079282fdc 100644 --- a/client/ssh/auth/auth.go +++ b/client/ssh/auth/auth.go @@ -98,19 +98,17 @@ func (a *Authorizer) Update(config *Config) { len(config.AuthorizedUsers), len(machineUsers)) } -// Authorize validates if a user is authorized to login as the specified OS user -// Returns nil if authorized, or an error describing why authorization failed -func (a *Authorizer) Authorize(jwtUserID, osUsername string) error { +// Authorize validates if a user is authorized to login as the specified OS user. +// Returns a success message describing how authorization was granted, or an error. +func (a *Authorizer) Authorize(jwtUserID, osUsername string) (string, error) { if jwtUserID == "" { - log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername) - return ErrEmptyUserID + return "", fmt.Errorf("JWT user ID is empty for OS user %q: %w", osUsername, ErrEmptyUserID) } // Hash the JWT user ID for comparison hashedUserID, err := sshuserhash.HashUserID(jwtUserID) if err != nil { - log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err) - return fmt.Errorf("failed to hash user ID: %w", err) + return "", fmt.Errorf("hash user ID %q for OS user %q: %w", jwtUserID, osUsername, err) } a.mu.RLock() @@ -119,8 +117,7 @@ func (a *Authorizer) Authorize(jwtUserID, osUsername string) error { // Find the index of this user in the authorized list userIndex, found := a.findUserIndex(hashedUserID) if !found { - log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername) - return ErrUserNotAuthorized + return "", fmt.Errorf("user %q (hash: %s) not in authorized list for OS user %q: %w", jwtUserID, hashedUserID, osUsername, ErrUserNotAuthorized) } return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex) @@ -128,12 +125,11 @@ func (a *Authorizer) Authorize(jwtUserID, osUsername string) error { // checkMachineUserMapping validates if a user's index is authorized for the specified OS user // Checks wildcard mapping first, then specific OS user mappings -func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error { +func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) (string, error) { // If wildcard exists and user's index is in the wildcard list, allow access to any OS user if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard { if a.isIndexInList(uint32(userIndex), wildcardIndexes) { - log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex) - return nil + return fmt.Sprintf("granted via wildcard (index: %d)", userIndex), nil } } @@ -141,18 +137,15 @@ func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userI allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername] if !hasMachineUserMapping { // No mapping for this OS user - deny by default (fail closed) - log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID) - return ErrNoMachineUserMapping + return "", fmt.Errorf("no machine user mapping for OS user %q (JWT user: %s): %w", osUsername, jwtUserID, ErrNoMachineUserMapping) } // Check if user's index is in the allowed indexes for this specific OS user if !a.isIndexInList(uint32(userIndex), allowedIndexes) { - log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex) - return ErrUserNotMappedToOSUser + return "", fmt.Errorf("user %q not mapped to OS user %q (index: %d): %w", jwtUserID, osUsername, userIndex, ErrUserNotMappedToOSUser) } - log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex) - return nil + return fmt.Sprintf("granted (index: %d)", userIndex), nil } // GetUserIDClaim returns the JWT claim name used to extract user IDs diff --git a/client/ssh/auth/auth_test.go b/client/ssh/auth/auth_test.go index 2b3b5a414..fa27b72e8 100644 --- a/client/ssh/auth/auth_test.go +++ b/client/ssh/auth/auth_test.go @@ -24,7 +24,7 @@ func TestAuthorizer_Authorize_UserNotInList(t *testing.T) { authorizer.Update(config) // Try to authorize a different user - err = authorizer.Authorize("unauthorized-user", "root") + _, err = authorizer.Authorize("unauthorized-user", "root") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotAuthorized) } @@ -45,15 +45,15 @@ func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) authorizer.Update(config) // All attempts should fail when no machine user mappings exist (fail closed) - err = authorizer.Authorize("user1", "root") + _, err = authorizer.Authorize("user1", "root") assert.Error(t, err) assert.ErrorIs(t, err, ErrNoMachineUserMapping) - err = authorizer.Authorize("user2", "admin") + _, err = authorizer.Authorize("user2", "admin") assert.Error(t, err) assert.ErrorIs(t, err, ErrNoMachineUserMapping) - err = authorizer.Authorize("user1", "postgres") + _, err = authorizer.Authorize("user1", "postgres") assert.Error(t, err) assert.ErrorIs(t, err, ErrNoMachineUserMapping) } @@ -80,21 +80,21 @@ func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testi authorizer.Update(config) // user1 (index 0) should access root and admin - err = authorizer.Authorize("user1", "root") + _, err = authorizer.Authorize("user1", "root") assert.NoError(t, err) - err = authorizer.Authorize("user1", "admin") + _, err = authorizer.Authorize("user1", "admin") assert.NoError(t, err) // user2 (index 1) should access root and postgres - err = authorizer.Authorize("user2", "root") + _, err = authorizer.Authorize("user2", "root") assert.NoError(t, err) - err = authorizer.Authorize("user2", "postgres") + _, err = authorizer.Authorize("user2", "postgres") assert.NoError(t, err) // user3 (index 2) should access postgres - err = authorizer.Authorize("user3", "postgres") + _, err = authorizer.Authorize("user3", "postgres") assert.NoError(t, err) } @@ -121,22 +121,22 @@ func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testin authorizer.Update(config) // user1 (index 0) should NOT access postgres - err = authorizer.Authorize("user1", "postgres") + _, err = authorizer.Authorize("user1", "postgres") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) // user2 (index 1) should NOT access admin - err = authorizer.Authorize("user2", "admin") + _, err = authorizer.Authorize("user2", "admin") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) // user3 (index 2) should NOT access root - err = authorizer.Authorize("user3", "root") + _, err = authorizer.Authorize("user3", "root") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) // user3 (index 2) should NOT access admin - err = authorizer.Authorize("user3", "admin") + _, err = authorizer.Authorize("user3", "admin") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) } @@ -158,7 +158,7 @@ func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) { authorizer.Update(config) // user1 should NOT access an unmapped OS user (fail closed) - err = authorizer.Authorize("user1", "postgres") + _, err = authorizer.Authorize("user1", "postgres") assert.Error(t, err) assert.ErrorIs(t, err, ErrNoMachineUserMapping) } @@ -178,7 +178,7 @@ func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) { authorizer.Update(config) // Empty user ID should fail - err = authorizer.Authorize("", "root") + _, err = authorizer.Authorize("", "root") assert.Error(t, err) assert.ErrorIs(t, err, ErrEmptyUserID) } @@ -211,12 +211,12 @@ func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) { // All users should be authorized for root for i := 0; i < 10; i++ { - err := authorizer.Authorize("user"+string(rune('0'+i)), "root") + _, err := authorizer.Authorize("user"+string(rune('0'+i)), "root") assert.NoError(t, err, "user%d should be authorized", i) } // User not in list should fail - err := authorizer.Authorize("unknown-user", "root") + _, err := authorizer.Authorize("unknown-user", "root") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotAuthorized) } @@ -236,14 +236,14 @@ func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) { authorizer.Update(config) // user1 should be authorized - err = authorizer.Authorize("user1", "root") + _, err = authorizer.Authorize("user1", "root") assert.NoError(t, err) // Clear configuration authorizer.Update(nil) // user1 should no longer be authorized - err = authorizer.Authorize("user1", "root") + _, err = authorizer.Authorize("user1", "root") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotAuthorized) } @@ -267,16 +267,16 @@ func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) { authorizer.Update(config) // root should work - err = authorizer.Authorize("user1", "root") + _, err = authorizer.Authorize("user1", "root") assert.NoError(t, err) // postgres should fail (no mapping) - err = authorizer.Authorize("user1", "postgres") + _, err = authorizer.Authorize("user1", "postgres") assert.Error(t, err) assert.ErrorIs(t, err, ErrNoMachineUserMapping) // admin should fail (no mapping) - err = authorizer.Authorize("user1", "admin") + _, err = authorizer.Authorize("user1", "admin") assert.Error(t, err) assert.ErrorIs(t, err, ErrNoMachineUserMapping) } @@ -301,7 +301,7 @@ func TestAuthorizer_CustomUserIDClaim(t *testing.T) { assert.Equal(t, "email", authorizer.GetUserIDClaim()) // Authorize with email as user ID - err = authorizer.Authorize("user@example.com", "root") + _, err = authorizer.Authorize("user@example.com", "root") assert.NoError(t, err) } @@ -349,19 +349,19 @@ func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) { authorizer.Update(config) // First user should have access - err := authorizer.Authorize("user"+string(rune(0)), "root") + _, err := authorizer.Authorize("user"+string(rune(0)), "root") assert.NoError(t, err) // Middle user should have access - err = authorizer.Authorize("user"+string(rune(500)), "root") + _, err = authorizer.Authorize("user"+string(rune(500)), "root") assert.NoError(t, err) // Last user should have access - err = authorizer.Authorize("user"+string(rune(999)), "root") + _, err = authorizer.Authorize("user"+string(rune(999)), "root") assert.NoError(t, err) // User not in mapping should NOT have access - err = authorizer.Authorize("user"+string(rune(100)), "root") + _, err = authorizer.Authorize("user"+string(rune(100)), "root") assert.Error(t, err) } @@ -393,7 +393,7 @@ func TestAuthorizer_ConcurrentAuthorization(t *testing.T) { if idx%2 == 0 { user = "user2" } - err := authorizer.Authorize(user, "root") + _, err := authorizer.Authorize(user, "root") errChan <- err }(i) } @@ -426,22 +426,22 @@ func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) { authorizer.Update(config) // All authorized users should be able to access any OS user - err = authorizer.Authorize("user1", "root") + _, err = authorizer.Authorize("user1", "root") assert.NoError(t, err) - err = authorizer.Authorize("user2", "postgres") + _, err = authorizer.Authorize("user2", "postgres") assert.NoError(t, err) - err = authorizer.Authorize("user3", "admin") + _, err = authorizer.Authorize("user3", "admin") assert.NoError(t, err) - err = authorizer.Authorize("user1", "ubuntu") + _, err = authorizer.Authorize("user1", "ubuntu") assert.NoError(t, err) - err = authorizer.Authorize("user2", "nginx") + _, err = authorizer.Authorize("user2", "nginx") assert.NoError(t, err) - err = authorizer.Authorize("user3", "docker") + _, err = authorizer.Authorize("user3", "docker") assert.NoError(t, err) } @@ -462,11 +462,11 @@ func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) { authorizer.Update(config) // user1 should have access - err = authorizer.Authorize("user1", "root") + _, err = authorizer.Authorize("user1", "root") assert.NoError(t, err) // Unauthorized user should still be denied even with wildcard - err = authorizer.Authorize("unauthorized-user", "root") + _, err = authorizer.Authorize("unauthorized-user", "root") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotAuthorized) } @@ -492,17 +492,17 @@ func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) { authorizer.Update(config) // Both users should be able to access root via wildcard (takes precedence over specific mapping) - err = authorizer.Authorize("user1", "root") + _, err = authorizer.Authorize("user1", "root") assert.NoError(t, err) - err = authorizer.Authorize("user2", "root") + _, err = authorizer.Authorize("user2", "root") assert.NoError(t, err) // Both users should be able to access any other OS user via wildcard - err = authorizer.Authorize("user1", "postgres") + _, err = authorizer.Authorize("user1", "postgres") assert.NoError(t, err) - err = authorizer.Authorize("user2", "admin") + _, err = authorizer.Authorize("user2", "admin") assert.NoError(t, err) } @@ -526,29 +526,29 @@ func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) { authorizer.Update(config) // user1 can access root - err = authorizer.Authorize("user1", "root") + _, err = authorizer.Authorize("user1", "root") assert.NoError(t, err) // user2 can access postgres - err = authorizer.Authorize("user2", "postgres") + _, err = authorizer.Authorize("user2", "postgres") assert.NoError(t, err) // user1 cannot access postgres - err = authorizer.Authorize("user1", "postgres") + _, err = authorizer.Authorize("user1", "postgres") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) // user2 cannot access root - err = authorizer.Authorize("user2", "root") + _, err = authorizer.Authorize("user2", "root") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) // Neither can access unmapped OS users - err = authorizer.Authorize("user1", "admin") + _, err = authorizer.Authorize("user1", "admin") assert.Error(t, err) assert.ErrorIs(t, err, ErrNoMachineUserMapping) - err = authorizer.Authorize("user2", "admin") + _, err = authorizer.Authorize("user2", "admin") assert.Error(t, err) assert.ErrorIs(t, err, ErrNoMachineUserMapping) } @@ -578,35 +578,35 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) { authorizer.Update(config) // wasm (index 0) should access any OS user via wildcard - err = authorizer.Authorize("wasm", "root") + _, err = authorizer.Authorize("wasm", "root") assert.NoError(t, err, "wasm should access root via wildcard") - err = authorizer.Authorize("wasm", "alice") + _, err = authorizer.Authorize("wasm", "alice") assert.NoError(t, err, "wasm should access alice via wildcard") - err = authorizer.Authorize("wasm", "bob") + _, err = authorizer.Authorize("wasm", "bob") assert.NoError(t, err, "wasm should access bob via wildcard") - err = authorizer.Authorize("wasm", "postgres") + _, err = authorizer.Authorize("wasm", "postgres") assert.NoError(t, err, "wasm should access postgres via wildcard") // user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres - err = authorizer.Authorize("user2", "alice") + _, err = authorizer.Authorize("user2", "alice") assert.NoError(t, err, "user2 should access alice via explicit mapping") - err = authorizer.Authorize("user2", "bob") + _, err = authorizer.Authorize("user2", "bob") assert.NoError(t, err, "user2 should access bob via explicit mapping") - err = authorizer.Authorize("user2", "root") + _, err = authorizer.Authorize("user2", "root") assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)") assert.ErrorIs(t, err, ErrNoMachineUserMapping) - err = authorizer.Authorize("user2", "postgres") + _, err = authorizer.Authorize("user2", "postgres") assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)") assert.ErrorIs(t, err, ErrNoMachineUserMapping) // Unauthorized user should still be denied - err = authorizer.Authorize("user3", "root") + _, err = authorizer.Authorize("user3", "root") assert.Error(t, err) assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied") } diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index aab222093..342da7303 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net" "os" "path/filepath" @@ -551,14 +550,15 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { defer func() { if err := localConn.Close(); err != nil { - log.Debugf("local connection close error: %v", err) + log.Debugf("local port forwarding: close local connection: %v", err) } }() channel, err := c.client.Dial("tcp", remoteAddr) if err != nil { - if strings.Contains(err.Error(), "administratively prohibited") { - _, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n") + var openErr *ssh.OpenChannelError + if errors.As(err, &openErr) && openErr.Reason == ssh.Prohibited { + _, _ = fmt.Fprintf(os.Stderr, "channel open failed: port forwarding is disabled\n") } else { log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err) } @@ -566,19 +566,11 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { } defer func() { if err := channel.Close(); err != nil { - log.Debugf("remote channel close error: %v", err) + log.Debugf("local port forwarding: close remote channel: %v", err) } }() - go func() { - if _, err := io.Copy(channel, localConn); err != nil { - log.Debugf("local forward copy error (local->remote): %v", err) - } - }() - - if _, err := io.Copy(localConn, channel); err != nil { - log.Debugf("local forward copy error (remote->local): %v", err) - } + nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel) } // RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr @@ -633,7 +625,7 @@ func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error { return fmt.Errorf("send tcpip-forward request: %w", err) } if !ok { - return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)") + return fmt.Errorf("remote port forwarding denied by server") } return nil } @@ -676,7 +668,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st } defer func() { if err := channel.Close(); err != nil { - log.Debugf("remote channel close error: %v", err) + log.Debugf("remote port forwarding: close remote channel: %v", err) } }() @@ -688,19 +680,11 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st } defer func() { if err := localConn.Close(); err != nil { - log.Debugf("local connection close error: %v", err) + log.Debugf("remote port forwarding: close local connection: %v", err) } }() - go func() { - if _, err := io.Copy(localConn, channel); err != nil { - log.Debugf("remote forward copy error (remote->local): %v", err) - } - }() - - if _, err := io.Copy(channel, localConn); err != nil { - log.Debugf("remote forward copy error (local->remote): %v", err) - } + nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel) } // tcpipForwardMsg represents the structure for tcpip-forward requests diff --git a/client/ssh/common.go b/client/ssh/common.go index 6574437b5..f6aec5f9c 100644 --- a/client/ssh/common.go +++ b/client/ssh/common.go @@ -193,3 +193,64 @@ func buildAddressList(hostname string, remote net.Addr) []string { } return addresses } + +// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections. +// It waits for both directions to complete before returning. +// The caller is responsible for closing the connections. +func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) { + done := make(chan struct{}, 2) + + go func() { + if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) { + logger.Debugf("copy error (1->2): %v", err) + } + done <- struct{}{} + }() + + go func() { + if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) { + logger.Debugf("copy error (2->1): %v", err) + } + done <- struct{}{} + }() + + <-done + <-done +} + +func isExpectedCopyError(err error) bool { + return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) +} + +// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections. +// It waits for both directions to complete or for context cancellation before returning. +// Both connections are closed when the function returns. +func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) { + done := make(chan struct{}, 2) + + go func() { + if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) { + logger.Debugf("copy error (1->2): %v", err) + } + done <- struct{}{} + }() + + go func() { + if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) { + logger.Debugf("copy error (2->1): %v", err) + } + done <- struct{}{} + }() + + select { + case <-ctx.Done(): + case <-done: + select { + case <-ctx.Done(): + case <-done: + } + } + + _ = conn1.Close() + _ = conn2.Close() +} diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index 4e807e33c..cb1c36e13 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "encoding/binary" "errors" "fmt" "io" @@ -42,6 +43,14 @@ type SSHProxy struct { conn *grpc.ClientConn daemonClient proto.DaemonServiceClient browserOpener func(string) error + + mu sync.RWMutex + backendClient *cryptossh.Client + // jwtToken is set once in runProxySSHServer before any handlers are called, + // so concurrent access is safe without additional synchronization. + jwtToken string + + forwardedChannelsOnce sync.Once } func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) { @@ -63,6 +72,17 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browse } func (p *SSHProxy) Close() error { + p.mu.Lock() + backendClient := p.backendClient + p.backendClient = nil + p.mu.Unlock() + + if backendClient != nil { + if err := backendClient.Close(); err != nil { + log.Debugf("close backend client: %v", err) + } + } + if p.conn != nil { return p.conn.Close() } @@ -77,16 +97,16 @@ func (p *SSHProxy) Connect(ctx context.Context) error { return fmt.Errorf(jwtAuthErrorMsg, err) } - return p.runProxySSHServer(ctx, jwtToken) + log.Debugf("JWT authentication successful, starting proxy to %s:%d", p.targetHost, p.targetPort) + return p.runProxySSHServer(jwtToken) } -func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error { +func (p *SSHProxy) runProxySSHServer(jwtToken string) error { + p.jwtToken = jwtToken serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion()) sshServer := &ssh.Server{ - Handler: func(s ssh.Session) { - p.handleSSHSession(ctx, s, jwtToken) - }, + Handler: p.handleSSHSession, ChannelHandlers: map[string]ssh.ChannelHandler{ "session": ssh.DefaultSessionHandler, "direct-tcpip": p.directTCPIPHandler, @@ -119,15 +139,20 @@ func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error return nil } -func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) { - targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) +func (p *SSHProxy) handleSSHSession(session ssh.Session) { + ptyReq, winCh, isPty := session.Pty() + hasCommand := len(session.Command()) > 0 - sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken) + sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User()) if err != nil { _, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err) return } - defer func() { _ = sshClient.Close() }() + + if !isPty && !hasCommand { + p.handleNonInteractiveSession(session, sshClient) + return + } serverSession, err := sshClient.NewSession() if err != nil { @@ -140,7 +165,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw serverSession.Stdout = session serverSession.Stderr = session.Stderr() - ptyReq, winCh, isPty := session.Pty() if isPty { if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil { log.Debugf("PTY request to backend: %v", err) @@ -155,7 +179,7 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw }() } - if len(session.Command()) > 0 { + if hasCommand { if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { log.Debugf("run command: %v", err) p.handleProxyExitCode(session, err) @@ -176,12 +200,29 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) { var exitErr *cryptossh.ExitError if errors.As(err, &exitErr) { - if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil { - log.Debugf("set exit status: %v", exitErr) + if err := session.Exit(exitErr.ExitStatus()); err != nil { + log.Debugf("set exit status: %v", err) } } } +func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) { + // Create a backend session to mirror the client's session request. + // This keeps the connection alive on the server side while port forwarding channels operate. + serverSession, err := sshClient.NewSession() + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err) + return + } + defer func() { _ = serverSession.Close() }() + + <-session.Context().Done() + + if err := session.Exit(0); err != nil { + log.Debugf("session exit: %v", err) + } +} + func generateHostKey() (ssh.Signer, error) { keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519) if err != nil { @@ -250,8 +291,52 @@ func (c *stdioConn) SetWriteDeadline(_ time.Time) error { return nil } -func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) { - _ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy") +// directTCPIPHandler handles local port forwarding (direct-tcpip channel). +func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, sshCtx ssh.Context) { + var payload struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 + } + if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil { + _, _ = fmt.Fprintf(p.stderr, "parse direct-tcpip payload: %v\n", err) + _ = newChan.Reject(cryptossh.ConnectionFailed, "invalid payload") + return + } + + dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort) + log.Debugf("local port forwarding: %s", dest) + + backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User()) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "backend connection for port forwarding: %v\n", err) + _ = newChan.Reject(cryptossh.ConnectionFailed, "backend connection failed") + return + } + + backendChan, backendReqs, err := backendClient.OpenChannel("direct-tcpip", newChan.ExtraData()) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "open backend channel for %s: %v\n", dest, err) + var openErr *cryptossh.OpenChannelError + if errors.As(err, &openErr) { + _ = newChan.Reject(openErr.Reason, openErr.Message) + } else { + _ = newChan.Reject(cryptossh.ConnectionFailed, err.Error()) + } + return + } + go cryptossh.DiscardRequests(backendReqs) + + clientChan, clientReqs, err := newChan.Accept() + if err != nil { + log.Debugf("local port forwarding: accept channel: %v", err) + _ = backendChan.Close() + return + } + go cryptossh.DiscardRequests(clientReqs) + + nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) } func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) { @@ -354,12 +439,143 @@ func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.Wr } } -func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { - return false, []byte("port forwarding not supported in proxy") +// tcpipForwardHandler handles remote port forwarding (tcpip-forward request). +func (p *SSHProxy) tcpipForwardHandler(sshCtx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { + var reqPayload struct { + Host string + Port uint32 + } + if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil { + _, _ = fmt.Fprintf(p.stderr, "parse tcpip-forward payload: %v\n", err) + return false, nil + } + + log.Debugf("tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port) + + backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User()) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "backend connection for remote port forwarding: %v\n", err) + return false, nil + } + + ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "forward tcpip-forward request for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err) + return false, nil + } + + if ok { + actualPort := reqPayload.Port + if reqPayload.Port == 0 && len(payload) >= 4 { + actualPort = binary.BigEndian.Uint32(payload) + } + log.Debugf("remote port forwarding established for %s:%d", reqPayload.Host, actualPort) + p.forwardedChannelsOnce.Do(func() { + go p.handleForwardedChannels(sshCtx, backendClient) + }) + } + + return ok, payload } -func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { - return true, nil +// cancelTcpipForwardHandler handles cancel-tcpip-forward request. +func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { + var reqPayload struct { + Host string + Port uint32 + } + if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil { + _, _ = fmt.Fprintf(p.stderr, "parse cancel-tcpip-forward payload: %v\n", err) + return false, nil + } + + log.Debugf("cancel-tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port) + + backendClient := p.getBackendClient() + if backendClient == nil { + return false, nil + } + + ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "cancel-tcpip-forward for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err) + return false, nil + } + + return ok, payload +} + +// getOrCreateBackendClient returns the existing backend client or creates a new one. +func (p *SSHProxy) getOrCreateBackendClient(ctx context.Context, user string) (*cryptossh.Client, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.backendClient != nil { + return p.backendClient, nil + } + + targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) + log.Debugf("connecting to backend %s", targetAddr) + + client, err := p.dialBackend(ctx, targetAddr, user, p.jwtToken) + if err != nil { + return nil, err + } + + log.Debugf("backend connection established to %s", targetAddr) + p.backendClient = client + return client, nil +} + +// getBackendClient returns the existing backend client or nil. +func (p *SSHProxy) getBackendClient() *cryptossh.Client { + p.mu.RLock() + defer p.mu.RUnlock() + return p.backendClient +} + +// handleForwardedChannels handles forwarded-tcpip channels from the backend for remote port forwarding. +// When the backend receives incoming connections on the forwarded port, it sends them as +// "forwarded-tcpip" channels which we need to proxy to the client. +func (p *SSHProxy) handleForwardedChannels(sshCtx ssh.Context, backendClient *cryptossh.Client) { + sshConn, ok := sshCtx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn) + if !ok || sshConn == nil { + log.Debugf("no SSH connection in context for forwarded channels") + return + } + + channelChan := backendClient.HandleChannelOpen("forwarded-tcpip") + for { + select { + case <-sshCtx.Done(): + return + case newChannel, ok := <-channelChan: + if !ok { + return + } + go p.handleForwardedChannel(sshCtx, sshConn, newChannel) + } + } +} + +// handleForwardedChannel handles a single forwarded-tcpip channel from the backend. +func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh.ServerConn, newChannel cryptossh.NewChannel) { + backendChan, backendReqs, err := newChannel.Accept() + if err != nil { + log.Debugf("remote port forwarding: accept from backend: %v", err) + return + } + go cryptossh.DiscardRequests(backendReqs) + + clientChan, clientReqs, err := sshConn.OpenChannel("forwarded-tcpip", newChannel.ExtraData()) + if err != nil { + log.Debugf("remote port forwarding: open to client: %v", err) + _ = backendChan.Close() + return + } + go cryptossh.DiscardRequests(clientReqs) + + nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) } func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) { diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index 6138f9296..c60cf4f58 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -1,25 +1,32 @@ +// Package server implements port forwarding for the SSH server. +// +// Security note: Port forwarding runs in the main server process without privilege separation. +// The attack surface is primarily io.Copy through well-tested standard library code, making it +// lower risk than shell execution which uses privilege-separated child processes. We enforce +// user-level port restrictions: non-privileged users cannot bind to ports < 1024. package server import ( "encoding/binary" "fmt" - "io" "net" + "runtime" "strconv" "github.com/gliderlabs/ssh" log "github.com/sirupsen/logrus" cryptossh "golang.org/x/crypto/ssh" + + nbssh "github.com/netbirdio/netbird/client/ssh" ) -// SessionKey uniquely identifies an SSH session -type SessionKey string +const privilegedPortThreshold = 1024 -// ConnectionKey uniquely identifies a port forwarding connection within a session -type ConnectionKey string +// sessionKey uniquely identifies an SSH session +type sessionKey string -// ForwardKey uniquely identifies a port forwarding listener -type ForwardKey string +// forwardKey uniquely identifies a port forwarding listener +type forwardKey string // tcpipForwardMsg represents the structure for tcpip-forward SSH requests type tcpipForwardMsg struct { @@ -47,34 +54,32 @@ func (s *Server) configurePortForwarding(server *ssh.Server) { allowRemote := s.allowRemotePortForwarding server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool { + logger := s.getRequestLogger(ctx) if !allowLocal { - log.Warnf("local port forwarding denied for %s from %s: disabled by configuration", - net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr()) + logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort) return false } if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil { - log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err) + logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err) return false } - log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort) return true } server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + logger := s.getRequestLogger(ctx) if !allowRemote { - log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration", - net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr()) + logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort) return false } if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil { - log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err) + logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err) return false } - log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort) return true } @@ -82,23 +87,20 @@ func (s *Server) configurePortForwarding(server *ssh.Server) { } // checkPortForwardingPrivileges validates privilege requirements for port forwarding operations. -// Returns nil if allowed, error if denied. +// For remote port forwarding (binding), it enforces that non-privileged users cannot bind to +// ports below 1024, mirroring the restriction they would face if binding directly. +// +// Note: FeatureSupportsUserSwitch is true because we accept requests from any authenticated user, +// though we don't actually switch users - port forwarding runs in the server process. The resolved +// user is used for privileged port access checks. func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error { if ctx == nil { return fmt.Errorf("%s port forwarding denied: no context", forwardType) } - username := ctx.User() - remoteAddr := "unknown" - if ctx.RemoteAddr() != nil { - remoteAddr = ctx.RemoteAddr().String() - } - - logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port}) - result := s.CheckPrivileges(PrivilegeCheckRequest{ - RequestedUsername: username, - FeatureSupportsUserSwitch: false, + RequestedUsername: ctx.User(), + FeatureSupportsUserSwitch: true, FeatureName: forwardType + " port forwarding", }) @@ -106,12 +108,42 @@ func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType stri return result.Error } - logger.Debugf("%s port forwarding allowed: user %s validated (port %d)", - forwardType, result.User.Username, port) + if err := s.checkPrivilegedPortAccess(forwardType, port, result); err != nil { + return err + } return nil } +// checkPrivilegedPortAccess enforces that non-privileged users cannot bind to privileged ports. +// This applies to remote port forwarding where the server binds a port on behalf of the user. +// On Windows, there is no privileged port restriction, so this check is skipped. +func (s *Server) checkPrivilegedPortAccess(forwardType string, port uint32, result PrivilegeCheckResult) error { + if runtime.GOOS == "windows" { + return nil + } + + isBindOperation := forwardType == "remote" || forwardType == "tcpip-forward" + if !isBindOperation { + return nil + } + + // Port 0 means "pick any available port", which will be >= 1024 + if port == 0 || port >= privilegedPortThreshold { + return nil + } + + if result.User != nil && isPrivilegedUsername(result.User.Username) { + return nil + } + + username := "unknown" + if result.User != nil { + username = result.User.Username + } + return fmt.Errorf("user %s cannot bind to privileged port %d (requires root)", username, port) +} + // tcpipForwardHandler handles tcpip-forward requests for remote port forwarding. func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { logger := s.getRequestLogger(ctx) @@ -132,8 +164,6 @@ func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *crypto return false, nil } - logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port) - sshConn, err := s.getSSHConnection(ctx) if err != nil { logger.Warnf("tcpip-forward request denied: %v", err) @@ -153,8 +183,10 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req * return false, nil } - key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) if s.removeRemoteForwardListener(key) { + forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port) + s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr) logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port) return true, nil } @@ -165,14 +197,11 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req * // handleRemoteForwardListener handles incoming connections for remote port forwarding. func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) { - log.Debugf("starting remote forward listener handler for %s:%d", host, port) + logger := s.getRequestLogger(ctx) defer func() { - log.Debugf("cleaning up remote forward listener for %s:%d", host, port) if err := ln.Close(); err != nil { - log.Debugf("remote forward listener close error: %v", err) - } else { - log.Debugf("remote forward listener closed successfully for %s:%d", host, port) + logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err) } }() @@ -196,28 +225,43 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h select { case result := <-acceptChan: if result.err != nil { - log.Debugf("remote forward accept error: %v", result.err) + logger.Debugf("remote forward accept error: %v", result.err) return } go s.handleRemoteForwardConnection(ctx, result.conn, host, port) case <-ctx.Done(): - log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port) + logger.Debugf("remote forward listener shutting down for %s:%d", host, port) return } } } -// getRequestLogger creates a logger with user and remote address context +// getRequestLogger creates a logger with session/conn and jwt_user context func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry { - remoteAddr := "unknown" - username := "unknown" - if ctx != nil { - if ctx.RemoteAddr() != nil { - remoteAddr = ctx.RemoteAddr().String() + sessionKey := s.findSessionKeyByContext(ctx) + + s.mu.RLock() + defer s.mu.RUnlock() + + if state, exists := s.sessions[sessionKey]; exists { + logger := log.WithField("session", sessionKey) + if state.jwtUsername != "" { + logger = logger.WithField("jwt_user", state.jwtUsername) } - username = ctx.User() + return logger } - return log.WithFields(log.Fields{"user": username, "remote": remoteAddr}) + + if ctx.RemoteAddr() != nil { + if connState, exists := s.connections[connKey(ctx.RemoteAddr().String())]; exists { + return s.connLogger(connState) + } + } + + remoteAddr := "unknown" + if ctx.RemoteAddr() != nil { + remoteAddr = ctx.RemoteAddr().String() + } + return log.WithField("session", fmt.Sprintf("%s@%s", ctx.User(), remoteAddr)) } // isRemotePortForwardingAllowed checks if remote port forwarding is enabled @@ -227,6 +271,13 @@ func (s *Server) isRemotePortForwardingAllowed() bool { return s.allowRemotePortForwarding } +// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled +func (s *Server) isPortForwardingEnabled() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.allowLocalPortForwarding || s.allowRemotePortForwarding +} + // parseTcpipForwardRequest parses the SSH request payload func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) { var payload tcpipForwardMsg @@ -267,10 +318,11 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host) } - key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) s.storeRemoteForwardListener(key, ln) - s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String()) + forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort) + s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort) response := make([]byte, 4) @@ -288,44 +340,34 @@ type acceptResult struct { // handleRemoteForwardConnection handles a single remote port forwarding connection func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) { - sessionKey := s.findSessionKeyByContext(ctx) - connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port) - logger := log.WithFields(log.Fields{ - "session": sessionKey, - "conn": connID, - }) + logger := s.getRequestLogger(ctx) - defer func() { - if err := conn.Close(); err != nil { - logger.Debugf("connection close error: %v", err) - } - }() - - sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn) - if sshConn == nil { + sshConn, ok := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn) + if !ok || sshConn == nil { logger.Debugf("remote forward: no SSH connection in context") + _ = conn.Close() return } remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr) if !ok { logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr()) + _ = conn.Close() return } - channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger) + channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr) if err != nil { - logger.Debugf("open forward channel: %v", err) + logger.Debugf("open forward channel for %s:%d: %v", host, port, err) + _ = conn.Close() return } - s.proxyForwardConnection(ctx, logger, conn, channel) + nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel) } // openForwardChannel creates an SSH forwarded-tcpip channel -func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) { - logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port) - +func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr) (cryptossh.Channel, error) { payload := struct { ConnectedAddress string ConnectedPort uint32 @@ -346,41 +388,3 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, go cryptossh.DiscardRequests(reqs) return channel, nil } - -// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel -func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) { - done := make(chan struct{}, 2) - - go func() { - if _, err := io.Copy(channel, conn); err != nil { - logger.Debugf("copy error (conn->channel): %v", err) - } - done <- struct{}{} - }() - - go func() { - if _, err := io.Copy(conn, channel); err != nil { - logger.Debugf("copy error (channel->conn): %v", err) - } - done <- struct{}{} - }() - - select { - case <-ctx.Done(): - logger.Debugf("session ended, closing connections") - case <-done: - // First copy finished, wait for second copy or context cancellation - select { - case <-ctx.Done(): - logger.Debugf("session ended, closing connections") - case <-done: - } - } - - if err := channel.Close(); err != nil { - logger.Debugf("channel close error: %v", err) - } - if err := conn.Close(); err != nil { - logger.Debugf("connection close error: %v", err) - } -} diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 82718d002..f957e66a5 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/netip" + "slices" "strings" "sync" "time" @@ -40,6 +41,11 @@ const ( msgPrivilegedUserDisabled = "privileged user login is disabled" + cmdInteractiveShell = "" + cmdPortForwarding = "" + cmdSFTP = "" + cmdNonInteractive = "" + // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server DefaultJWTMaxTokenAge = 5 * 60 ) @@ -90,10 +96,10 @@ func logSessionExitError(logger *log.Entry, err error) { } } -// safeLogCommand returns a safe representation of the command for logging +// safeLogCommand returns a safe representation of the command for logging. func safeLogCommand(cmd []string) string { if len(cmd) == 0 { - return "" + return cmdInteractiveShell } if len(cmd) == 1 { return cmd[0] @@ -101,26 +107,50 @@ func safeLogCommand(cmd []string) string { return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1) } -type sshConnectionState struct { - hasActivePortForward bool - username string - remoteAddr string +// connState tracks the state of an SSH connection for port forwarding and status display. +type connState struct { + username string + remoteAddr net.Addr + portForwards []string + jwtUsername string } +// authKey uniquely identifies an authentication attempt by username and remote address. +// Used to temporarily store JWT username between passwordHandler and sessionHandler. type authKey string +// connKey uniquely identifies an SSH connection by its remote address. +// Used to track authenticated connections for status display and port forwarding. +type connKey string + func newAuthKey(username string, remoteAddr net.Addr) authKey { return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String())) } +// sessionState tracks an active SSH session (shell, command, or subsystem like SFTP). +type sessionState struct { + session ssh.Session + sessionType string + jwtUsername string +} + type Server struct { - sshServer *ssh.Server - mu sync.RWMutex - hostKeyPEM []byte - sessions map[SessionKey]ssh.Session - sessionCancels map[ConnectionKey]context.CancelFunc - sessionJWTUsers map[SessionKey]string - pendingAuthJWT map[authKey]string + sshServer *ssh.Server + mu sync.RWMutex + hostKeyPEM []byte + + // sessions tracks active SSH sessions (shell, command, SFTP). + // These are created when a client opens a session channel and requests shell/exec/subsystem. + sessions map[sessionKey]*sessionState + + // pendingAuthJWT temporarily stores JWT username during the auth→session handoff. + // Populated in passwordHandler, consumed in sessionHandler/sftpSubsystemHandler. + pendingAuthJWT map[authKey]string + + // connections tracks all SSH connections by their remote address. + // Populated at authentication time, stores JWT username and port forwards for status display. + connections map[connKey]*connState + allowLocalPortForwarding bool allowRemotePortForwarding bool @@ -132,8 +162,7 @@ type Server struct { wgAddress wgaddr.Address - remoteForwardListeners map[ForwardKey]net.Listener - sshConnections map[*cryptossh.ServerConn]*sshConnectionState + remoteForwardListeners map[forwardKey]net.Listener jwtValidator *jwt.Validator jwtExtractor *jwt.ClaimsExtractor @@ -167,6 +196,7 @@ type SessionInfo struct { RemoteAddress string Command string JWTUsername string + PortForwards []string } // New creates an SSH server instance with the provided host key and optional JWT configuration @@ -175,11 +205,10 @@ func New(config *Config) *Server { s := &Server{ mu: sync.RWMutex{}, hostKeyPEM: config.HostKeyPEM, - sessions: make(map[SessionKey]ssh.Session), - sessionJWTUsers: make(map[SessionKey]string), + sessions: make(map[sessionKey]*sessionState), pendingAuthJWT: make(map[authKey]string), - remoteForwardListeners: make(map[ForwardKey]net.Listener), - sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState), + remoteForwardListeners: make(map[forwardKey]net.Listener), + connections: make(map[connKey]*connState), jwtEnabled: config.JWT != nil, jwtConfig: config.JWT, authorizer: sshauth.NewAuthorizer(), // Initialize with empty config @@ -265,14 +294,8 @@ func (s *Server) Stop() error { s.sshServer = nil maps.Clear(s.sessions) - maps.Clear(s.sessionJWTUsers) maps.Clear(s.pendingAuthJWT) - maps.Clear(s.sshConnections) - - for _, cancelFunc := range s.sessionCancels { - cancelFunc() - } - maps.Clear(s.sessionCancels) + maps.Clear(s.connections) for _, listener := range s.remoteForwardListeners { if err := listener.Close(); err != nil { @@ -284,32 +307,70 @@ func (s *Server) Stop() error { return nil } -// GetStatus returns the current status of the SSH server and active sessions +// GetStatus returns the current status of the SSH server and active sessions. func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) { s.mu.RLock() defer s.mu.RUnlock() enabled = s.sshServer != nil + reportedAddrs := make(map[string]bool) - for sessionKey, session := range s.sessions { - cmd := "" - if len(session.Command()) > 0 { - cmd = safeLogCommand(session.Command()) + for _, state := range s.sessions { + info := s.buildSessionInfo(state) + reportedAddrs[info.RemoteAddress] = true + sessions = append(sessions, info) + } + + // Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only) + for key, connState := range s.connections { + remoteAddr := string(key) + if reportedAddrs[remoteAddr] { + continue + } + cmd := cmdNonInteractive + if len(connState.portForwards) > 0 { + cmd = cmdPortForwarding } - - jwtUsername := s.sessionJWTUsers[sessionKey] - sessions = append(sessions, SessionInfo{ - Username: session.User(), - RemoteAddress: session.RemoteAddr().String(), + Username: connState.username, + RemoteAddress: remoteAddr, Command: cmd, - JWTUsername: jwtUsername, + JWTUsername: connState.jwtUsername, + PortForwards: connState.portForwards, }) } return enabled, sessions } +func (s *Server) buildSessionInfo(state *sessionState) SessionInfo { + session := state.session + cmd := state.sessionType + if cmd == "" { + cmd = safeLogCommand(session.Command()) + } + + remoteAddr := session.RemoteAddr().String() + info := SessionInfo{ + Username: session.User(), + RemoteAddress: remoteAddr, + Command: cmd, + JWTUsername: state.jwtUsername, + } + + connState, exists := s.connections[connKey(remoteAddr)] + if !exists { + return info + } + + info.PortForwards = connState.portForwards + if len(connState.portForwards) > 0 && (cmd == cmdInteractiveShell || cmd == cmdNonInteractive) { + info.Command = cmdPortForwarding + } + + return info +} + // SetNetstackNet sets the netstack network for userspace networking func (s *Server) SetNetstackNet(net *netstack.Net) { s.mu.Lock() @@ -520,69 +581,129 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int func (s *Server) passwordHandler(ctx ssh.Context, password string) bool { osUsername := ctx.User() remoteAddr := ctx.RemoteAddr() + logger := s.getRequestLogger(ctx) if err := s.ensureJWTValidator(); err != nil { - log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err) + logger.Errorf("JWT validator initialization failed: %v", err) return false } token, err := s.validateJWTToken(password) if err != nil { - log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err) + logger.Warnf("JWT authentication failed: %v", err) return false } userAuth, err := s.extractAndValidateUser(token) if err != nil { - log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err) + logger.Warnf("user validation failed: %v", err) return false } + logger = logger.WithField("jwt_user", userAuth.UserId) + s.mu.RLock() authorizer := s.authorizer s.mu.RUnlock() - if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil { - log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err) + msg, err := authorizer.Authorize(userAuth.UserId, osUsername) + if err != nil { + logger.Warnf("SSH auth denied: %v", err) return false } + logger.Infof("SSH auth %s", msg) + key := newAuthKey(osUsername, remoteAddr) + remoteAddrStr := ctx.RemoteAddr().String() s.mu.Lock() s.pendingAuthJWT[key] = userAuth.UserId + s.connections[connKey(remoteAddrStr)] = &connState{ + username: ctx.User(), + remoteAddr: ctx.RemoteAddr(), + jwtUsername: userAuth.UserId, + } s.mu.Unlock() - log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr) return true } -func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) { +func (s *Server) addConnectionPortForward(username string, remoteAddr net.Addr, forwardAddr string) { s.mu.Lock() defer s.mu.Unlock() - if state, exists := s.sshConnections[sshConn]; exists { - state.hasActivePortForward = true - } else { - s.sshConnections[sshConn] = &sshConnectionState{ - hasActivePortForward: true, - username: username, - remoteAddr: remoteAddr, + key := connKey(remoteAddr.String()) + if state, exists := s.connections[key]; exists { + if !slices.Contains(state.portForwards, forwardAddr) { + state.portForwards = append(state.portForwards, forwardAddr) } + return + } + + // Connection not in connections (non-JWT auth path) + s.connections[key] = &connState{ + username: username, + remoteAddr: remoteAddr, + portForwards: []string{forwardAddr}, + jwtUsername: s.pendingAuthJWT[newAuthKey(username, remoteAddr)], } } -func (s *Server) connectionCloseHandler(conn net.Conn, err error) { - // We can't extract the SSH connection from net.Conn directly - // Connection cleanup will happen during session cleanup or via timeout - log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err) +func (s *Server) removeConnectionPortForward(remoteAddr net.Addr, forwardAddr string) { + s.mu.Lock() + defer s.mu.Unlock() + + state, exists := s.connections[connKey(remoteAddr.String())] + if !exists { + return + } + + state.portForwards = slices.DeleteFunc(state.portForwards, func(addr string) bool { + return addr == forwardAddr + }) } -func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey { +// trackedConn wraps a net.Conn to detect when it closes +type trackedConn struct { + net.Conn + server *Server + remoteAddr string + onceClose sync.Once +} + +func (c *trackedConn) Close() error { + err := c.Conn.Close() + c.onceClose.Do(func() { + c.server.handleConnectionClose(c.remoteAddr) + }) + return err +} + +func (s *Server) handleConnectionClose(remoteAddr string) { + s.mu.Lock() + defer s.mu.Unlock() + + key := connKey(remoteAddr) + state, exists := s.connections[key] + if exists && len(state.portForwards) > 0 { + s.connLogger(state).Info("port forwarding connection closed") + } + delete(s.connections, key) +} + +func (s *Server) connLogger(state *connState) *log.Entry { + logger := log.WithField("session", fmt.Sprintf("%s@%s", state.username, state.remoteAddr)) + if state.jwtUsername != "" { + logger = logger.WithField("jwt_user", state.jwtUsername) + } + return logger +} + +func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey { if ctx == nil { return "unknown" } - // Try to match by SSH connection sshConn := ctx.Value(ssh.ContextKeyConn) if sshConn == nil { return "unknown" @@ -591,19 +712,14 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey { s.mu.RLock() defer s.mu.RUnlock() - // Look through sessions to find one with matching connection - for sessionKey, session := range s.sessions { - if session.Context().Value(ssh.ContextKeyConn) == sshConn { + for sessionKey, state := range s.sessions { + if state.session.Context().Value(ssh.ContextKeyConn) == sshConn { return sessionKey } } - // If no session found, this might be during early connection setup - // Return a temporary key that we'll fix up later if ctx.User() != "" && ctx.RemoteAddr() != nil { - tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String())) - log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey) - return tempKey + return sessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String())) } return "unknown" @@ -644,7 +760,11 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { } log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr) - return conn + return &trackedConn{ + Conn: conn, + server: s, + remoteAddr: conn.RemoteAddr().String(), + } } func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { @@ -672,9 +792,8 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { "tcpip-forward": s.tcpipForwardHandler, "cancel-tcpip-forward": s.cancelTcpipForwardHandler, }, - ConnCallback: s.connectionValidator, - ConnectionFailedCallback: s.connectionCloseHandler, - Version: serverVersion, + ConnCallback: s.connectionValidator, + Version: serverVersion, } if s.jwtEnabled { @@ -690,13 +809,13 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { return server, nil } -func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) { +func (s *Server) storeRemoteForwardListener(key forwardKey, ln net.Listener) { s.mu.Lock() defer s.mu.Unlock() s.remoteForwardListeners[key] = ln } -func (s *Server) removeRemoteForwardListener(key ForwardKey) bool { +func (s *Server) removeRemoteForwardListener(key forwardKey) bool { s.mu.Lock() defer s.mu.Unlock() @@ -714,6 +833,8 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool { } func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) { + logger := s.getRequestLogger(ctx) + var payload struct { Host string Port uint32 @@ -723,7 +844,7 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil { if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil { - log.Debugf("channel reject error: %v", err) + logger.Debugf("channel reject error: %v", err) } return } @@ -733,19 +854,20 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, s.mu.RUnlock() if !allowLocal { - log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port) + logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port) _ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled") return } - // Check privilege requirements for the destination port if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil { - log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) + logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) _ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges") return } - log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) + forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port) + s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) + logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) } diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go index 24e455025..d85d85a51 100644 --- a/client/ssh/server/server_config_test.go +++ b/client/ssh/server/server_config_test.go @@ -224,6 +224,96 @@ func TestServer_PortForwardingRestriction(t *testing.T) { } } +func TestServer_PrivilegedPortAccess(t *testing.T) { + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + } + server := New(serverConfig) + server.SetAllowRemotePortForwarding(true) + + tests := []struct { + name string + forwardType string + port uint32 + username string + expectError bool + errorMsg string + skipOnWindows bool + }{ + { + name: "non-root user remote forward privileged port", + forwardType: "remote", + port: 80, + username: "testuser", + expectError: true, + errorMsg: "cannot bind to privileged port", + skipOnWindows: true, + }, + { + name: "non-root user tcpip-forward privileged port", + forwardType: "tcpip-forward", + port: 443, + username: "testuser", + expectError: true, + errorMsg: "cannot bind to privileged port", + skipOnWindows: true, + }, + { + name: "non-root user remote forward unprivileged port", + forwardType: "remote", + port: 8080, + username: "testuser", + expectError: false, + }, + { + name: "non-root user remote forward port 0", + forwardType: "remote", + port: 0, + username: "testuser", + expectError: false, + }, + { + name: "root user remote forward privileged port", + forwardType: "remote", + port: 22, + username: "root", + expectError: false, + }, + { + name: "local forward privileged port allowed for non-root", + forwardType: "local", + port: 80, + username: "testuser", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipOnWindows && runtime.GOOS == "windows" { + t.Skip("Windows does not have privileged port restrictions") + } + + result := PrivilegeCheckResult{ + Allowed: true, + User: &user.User{Username: tt.username}, + } + + err := server.checkPrivilegedPortAccess(tt.forwardType, tt.port, result) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + func TestServer_PortConflictHandling(t *testing.T) { // Test that multiple sessions requesting the same local port are handled naturally by the OS // Get current user for SSH connection @@ -392,3 +482,95 @@ func TestServer_IsPrivilegedUser(t *testing.T) { }) } } + +func TestServer_PortForwardingOnlySession(t *testing.T) { + // Test that sessions without PTY and command are allowed when port forwarding is enabled + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + tests := []struct { + name string + allowLocalForwarding bool + allowRemoteForwarding bool + expectAllowed bool + description string + }{ + { + name: "session_allowed_with_local_forwarding", + allowLocalForwarding: true, + allowRemoteForwarding: false, + expectAllowed: true, + description: "Port-forwarding-only session should be allowed when local forwarding is enabled", + }, + { + name: "session_allowed_with_remote_forwarding", + allowLocalForwarding: false, + allowRemoteForwarding: true, + expectAllowed: true, + description: "Port-forwarding-only session should be allowed when remote forwarding is enabled", + }, + { + name: "session_allowed_with_both", + allowLocalForwarding: true, + allowRemoteForwarding: true, + expectAllowed: true, + description: "Port-forwarding-only session should be allowed when both forwarding types enabled", + }, + { + name: "session_denied_without_forwarding", + allowLocalForwarding: false, + allowRemoteForwarding: false, + expectAllowed: false, + description: "Port-forwarding-only session should be denied when all forwarding is disabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + server.SetAllowLocalPortForwarding(tt.allowLocalForwarding) + server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding) + + serverAddr := StartTestServer(t, server) + defer func() { + _ = server.Stop() + }() + + // Connect to the server without requesting PTY or command + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + client, err := sshclient.Dial(ctx, serverAddr, currentUser.Username, sshclient.DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + _ = client.Close() + }() + + // Execute a command without PTY - this simulates ssh -T with no command + // The server should either allow it (port forwarding enabled) or reject it + output, err := client.ExecuteCommand(ctx, "") + if tt.expectAllowed { + // When allowed, the session stays open until cancelled + // ExecuteCommand with empty command should return without error + assert.NoError(t, err, "Session should be allowed when port forwarding is enabled") + assert.NotContains(t, output, "port forwarding is disabled", + "Output should not contain port forwarding disabled message") + } else if err != nil { + // When denied, we expect an error message about port forwarding being disabled + assert.Contains(t, err.Error(), "port forwarding is disabled", + "Should get port forwarding disabled message") + } + }) + } +} diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index 4e6d72098..3fd578064 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -6,37 +6,45 @@ import ( "errors" "fmt" "io" - "strings" "time" "github.com/gliderlabs/ssh" log "github.com/sirupsen/logrus" - cryptossh "golang.org/x/crypto/ssh" ) +// associateJWTUsername extracts pending JWT username for the session and associates it with the session state. +// Returns the JWT username (empty if none) for logging purposes. +func (s *Server) associateJWTUsername(sess ssh.Session, sessionKey sessionKey) string { + key := newAuthKey(sess.User(), sess.RemoteAddr()) + + s.mu.Lock() + defer s.mu.Unlock() + + jwtUsername := s.pendingAuthJWT[key] + if jwtUsername == "" { + return "" + } + + if state, exists := s.sessions[sessionKey]; exists { + state.jwtUsername = jwtUsername + } + delete(s.pendingAuthJWT, key) + return jwtUsername +} + // sessionHandler handles SSH sessions func (s *Server) sessionHandler(session ssh.Session) { - sessionKey := s.registerSession(session) - - key := newAuthKey(session.User(), session.RemoteAddr()) - s.mu.Lock() - jwtUsername := s.pendingAuthJWT[key] - if jwtUsername != "" { - s.sessionJWTUsers[sessionKey] = jwtUsername - delete(s.pendingAuthJWT, key) - } - s.mu.Unlock() + sessionKey := s.registerSession(session, "") + jwtUsername := s.associateJWTUsername(session, sessionKey) logger := log.WithField("session", sessionKey) if jwtUsername != "" { logger = logger.WithField("jwt_user", jwtUsername) - logger.Infof("SSH session started (JWT user: %s)", jwtUsername) - } else { - logger.Infof("SSH session started") } + logger.Info("SSH session started") sessionStart := time.Now() - defer s.unregisterSession(sessionKey, session) + defer s.unregisterSession(sessionKey) defer func() { duration := time.Since(sessionStart).Round(time.Millisecond) if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { @@ -65,27 +73,52 @@ func (s *Server) sessionHandler(session ssh.Session) { // ssh - non-Pty command execution s.handleCommand(logger, session, privilegeResult, nil) default: - s.rejectInvalidSession(logger, session) + // ssh -T (or ssh -N) - no PTY, no command + s.handleNonInteractiveSession(logger, session) } } -func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) { - if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil { - logger.Debugf(errWriteSession, err) +// handleNonInteractiveSession handles sessions that have no PTY and no command. +// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N). +func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) { + s.updateSessionType(session, cmdNonInteractive) + + if !s.isPortForwardingEnabled() { + if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil { + logger.Debugf(errWriteSession, err) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + logger.Infof("rejected non-interactive session: port forwarding disabled") + return } - if err := session.Exit(1); err != nil { + + <-session.Context().Done() + + if err := session.Exit(0); err != nil { logSessionExitError(logger, err) } - logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr()) } -func (s *Server) registerSession(session ssh.Session) SessionKey { +func (s *Server) updateSessionType(session ssh.Session, sessionType string) { + s.mu.Lock() + defer s.mu.Unlock() + + for _, state := range s.sessions { + if state.session == session { + state.sessionType = sessionType + return + } + } +} + +func (s *Server) registerSession(session ssh.Session, sessionType string) sessionKey { sessionID := session.Context().Value(ssh.ContextKeySessionID) if sessionID == nil { sessionID = fmt.Sprintf("%p", session) } - // Create a short 4-byte identifier from the full session ID hasher := sha256.New() hasher.Write([]byte(fmt.Sprintf("%v", sessionID))) hash := hasher.Sum(nil) @@ -93,43 +126,23 @@ func (s *Server) registerSession(session ssh.Session) SessionKey { remoteAddr := session.RemoteAddr().String() username := session.User() - sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID)) + sessionKey := sessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID)) s.mu.Lock() - s.sessions[sessionKey] = session + s.sessions[sessionKey] = &sessionState{ + session: session, + sessionType: sessionType, + } s.mu.Unlock() return sessionKey } -func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) { +func (s *Server) unregisterSession(sessionKey sessionKey) { s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, sessionKey) - delete(s.sessionJWTUsers, sessionKey) - - // Cancel all port forwarding connections for this session - var connectionsToCancel []ConnectionKey - for key := range s.sessionCancels { - if strings.HasPrefix(string(key), string(sessionKey)+"-") { - connectionsToCancel = append(connectionsToCancel, key) - } - } - - for _, key := range connectionsToCancel { - if cancelFunc, exists := s.sessionCancels[key]; exists { - log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key) - cancelFunc() - delete(s.sessionCancels, key) - } - } - - if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil { - if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok { - delete(s.sshConnections, sshConn) - } - } - - s.mu.Unlock() } func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) { diff --git a/client/ssh/server/sftp.go b/client/ssh/server/sftp.go index c2b9f552b..199444abb 100644 --- a/client/ssh/server/sftp.go +++ b/client/ssh/server/sftp.go @@ -18,14 +18,26 @@ func (s *Server) SetAllowSFTP(allow bool) { // sftpSubsystemHandler handles SFTP subsystem requests func (s *Server) sftpSubsystemHandler(sess ssh.Session) { + sessionKey := s.registerSession(sess, cmdSFTP) + defer s.unregisterSession(sessionKey) + + jwtUsername := s.associateJWTUsername(sess, sessionKey) + + logger := log.WithField("session", sessionKey) + if jwtUsername != "" { + logger = logger.WithField("jwt_user", jwtUsername) + } + logger.Info("SFTP session started") + defer logger.Info("SFTP session closed") + s.mu.RLock() allowSFTP := s.allowSFTP s.mu.RUnlock() if !allowSFTP { - log.Debugf("SFTP subsystem request denied: SFTP disabled") + logger.Debug("SFTP subsystem request denied: SFTP disabled") if err := sess.Exit(1); err != nil { - log.Debugf("SFTP session exit failed: %v", err) + logger.Debugf("SFTP session exit: %v", err) } return } @@ -37,31 +49,27 @@ func (s *Server) sftpSubsystemHandler(sess ssh.Session) { }) if !result.Allowed { - log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error) + logger.Warnf("SFTP access denied: %v", result.Error) if err := sess.Exit(1); err != nil { - log.Debugf("exit SFTP session: %v", err) + logger.Debugf("exit SFTP session: %v", err) } return } - log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username) - if !result.RequiresUserSwitching { if err := s.executeSftpDirect(sess); err != nil { - log.Errorf("SFTP direct execution: %v", err) + logger.Errorf("SFTP direct execution: %v", err) } return } if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil { - log.Errorf("SFTP privilege drop execution: %v", err) + logger.Errorf("SFTP privilege drop execution: %v", err) } } // executeSftpDirect executes SFTP directly without privilege dropping func (s *Server) executeSftpDirect(sess ssh.Session) error { - log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User()) - sftpServer, err := sftp.NewServer(sess) if err != nil { return fmt.Errorf("SFTP server creation: %w", err) diff --git a/client/status/status.go b/client/status/status.go index d975f0e29..4f31f3637 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -82,10 +82,11 @@ type NsServerGroupStateOutput struct { } type SSHSessionOutput struct { - Username string `json:"username" yaml:"username"` - RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"` - Command string `json:"command" yaml:"command"` - JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"` + Username string `json:"username" yaml:"username"` + RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"` + Command string `json:"command" yaml:"command"` + JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"` + PortForwards []string `json:"portForwards,omitempty" yaml:"portForwards,omitempty"` } type SSHServerStateOutput struct { @@ -220,6 +221,7 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput { RemoteAddress: session.GetRemoteAddress(), Command: session.GetCommand(), JWTUsername: session.GetJwtUsername(), + PortForwards: session.GetPortForwards(), }) } @@ -475,6 +477,9 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, ) } sshServerStatus += "\n " + sessionDisplay + for _, pf := range session.PortForwards { + sshServerStatus += "\n " + pf + } } } }