[client] Add port forwarding to ssh proxy (#5031)

* Implement port forwarding for the ssh proxy

* Allow user switching for port forwarding
This commit is contained in:
Viktor Liu
2026-01-07 12:18:04 +08:00
committed by GitHub
parent 7142d45ef3
commit f012fb8592
15 changed files with 1006 additions and 370 deletions

View File

@@ -634,7 +634,11 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
return err return err
} }
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr) if err := validateDestinationPort(remoteAddr); err != nil {
return fmt.Errorf("invalid remote address: %w", err)
}
log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)
go func() { go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) { if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -652,7 +656,11 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return err return err
} }
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr) if err := validateDestinationPort(localAddr); err != nil {
return fmt.Errorf("invalid local address: %w", err)
}
log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)
go func() { go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) { if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -663,6 +671,35 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return nil return nil
} }
// validateDestinationPort checks that the destination address has a valid port.
// Port 0 is only valid for bind addresses (where the OS picks an available port),
// not for destination addresses where we need to connect.
func validateDestinationPort(addr string) error {
if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
return nil
}
_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("parse address %s: %w", addr, err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port %s: %w", portStr, err)
}
if port == 0 {
return fmt.Errorf("port 0 is not valid for destination address")
}
if port < 0 || port > 65535 {
return fmt.Errorf("port %d out of range (1-65535)", port)
}
return nil
}
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80". // parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket". // Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) { func parsePortForwardSpec(spec string) (string, string, error) {

View File

@@ -2013,6 +2013,7 @@ type SSHSessionInfo struct {
RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"` RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"`
Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"` Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"`
JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"` JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"`
PortForwards []string `protobuf:"bytes,5,rep,name=portForwards,proto3" json:"portForwards,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -2075,6 +2076,13 @@ func (x *SSHSessionInfo) GetJwtUsername() string {
return "" return ""
} }
func (x *SSHSessionInfo) GetPortForwards() []string {
if x != nil {
return x.PortForwards
}
return nil
}
// SSHServerState contains the latest state of the SSH server // SSHServerState contains the latest state of the SSH server
type SSHServerState struct { type SSHServerState struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
@@ -5706,12 +5714,13 @@ const file_daemon_proto_rawDesc = "" +
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" + "\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" + "\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +
"\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" + "\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" +
"\x05error\x18\x04 \x01(\tR\x05error\"\x8e\x01\n" + "\x05error\x18\x04 \x01(\tR\x05error\"\xb2\x01\n" +
"\x0eSSHSessionInfo\x12\x1a\n" + "\x0eSSHSessionInfo\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12$\n" + "\busername\x18\x01 \x01(\tR\busername\x12$\n" +
"\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" + "\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" +
"\acommand\x18\x03 \x01(\tR\acommand\x12 \n" + "\acommand\x18\x03 \x01(\tR\acommand\x12 \n" +
"\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" + "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\x12\"\n" +
"\fportForwards\x18\x05 \x03(\tR\fportForwards\"^\n" +
"\x0eSSHServerState\x12\x18\n" + "\x0eSSHServerState\x12\x18\n" +
"\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" + "\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" +
"\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" + "\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" +

View File

@@ -372,6 +372,7 @@ message SSHSessionInfo {
string remoteAddress = 2; string remoteAddress = 2;
string command = 3; string command = 3;
string jwtUsername = 4; string jwtUsername = 4;
repeated string portForwards = 5;
} }
// SSHServerState contains the latest state of the SSH server // SSHServerState contains the latest state of the SSH server

View File

@@ -1104,6 +1104,7 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
RemoteAddress: session.RemoteAddress, RemoteAddress: session.RemoteAddress,
Command: session.Command, Command: session.Command,
JwtUsername: session.JWTUsername, JwtUsername: session.JWTUsername,
PortForwards: session.PortForwards,
}) })
} }

View File

@@ -98,19 +98,17 @@ func (a *Authorizer) Update(config *Config) {
len(config.AuthorizedUsers), len(machineUsers)) len(config.AuthorizedUsers), len(machineUsers))
} }
// Authorize validates if a user is authorized to login as the specified OS user // Authorize validates if a user is authorized to login as the specified OS user.
// Returns nil if authorized, or an error describing why authorization failed // Returns a success message describing how authorization was granted, or an error.
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error { func (a *Authorizer) Authorize(jwtUserID, osUsername string) (string, error) {
if jwtUserID == "" { if jwtUserID == "" {
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername) return "", fmt.Errorf("JWT user ID is empty for OS user %q: %w", osUsername, ErrEmptyUserID)
return ErrEmptyUserID
} }
// Hash the JWT user ID for comparison // Hash the JWT user ID for comparison
hashedUserID, err := sshuserhash.HashUserID(jwtUserID) hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
if err != nil { if err != nil {
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err) return "", fmt.Errorf("hash user ID %q for OS user %q: %w", jwtUserID, osUsername, err)
return fmt.Errorf("failed to hash user ID: %w", err)
} }
a.mu.RLock() a.mu.RLock()
@@ -119,8 +117,7 @@ func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
// Find the index of this user in the authorized list // Find the index of this user in the authorized list
userIndex, found := a.findUserIndex(hashedUserID) userIndex, found := a.findUserIndex(hashedUserID)
if !found { if !found {
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername) return "", fmt.Errorf("user %q (hash: %s) not in authorized list for OS user %q: %w", jwtUserID, hashedUserID, osUsername, ErrUserNotAuthorized)
return ErrUserNotAuthorized
} }
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex) return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
@@ -128,12 +125,11 @@ func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user // checkMachineUserMapping validates if a user's index is authorized for the specified OS user
// Checks wildcard mapping first, then specific OS user mappings // Checks wildcard mapping first, then specific OS user mappings
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error { func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) (string, error) {
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user // If wildcard exists and user's index is in the wildcard list, allow access to any OS user
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard { if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
if a.isIndexInList(uint32(userIndex), wildcardIndexes) { if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex) return fmt.Sprintf("granted via wildcard (index: %d)", userIndex), nil
return nil
} }
} }
@@ -141,18 +137,15 @@ func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userI
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername] allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
if !hasMachineUserMapping { if !hasMachineUserMapping {
// No mapping for this OS user - deny by default (fail closed) // No mapping for this OS user - deny by default (fail closed)
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID) return "", fmt.Errorf("no machine user mapping for OS user %q (JWT user: %s): %w", osUsername, jwtUserID, ErrNoMachineUserMapping)
return ErrNoMachineUserMapping
} }
// Check if user's index is in the allowed indexes for this specific OS user // Check if user's index is in the allowed indexes for this specific OS user
if !a.isIndexInList(uint32(userIndex), allowedIndexes) { if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex) return "", fmt.Errorf("user %q not mapped to OS user %q (index: %d): %w", jwtUserID, osUsername, userIndex, ErrUserNotMappedToOSUser)
return ErrUserNotMappedToOSUser
} }
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex) return fmt.Sprintf("granted (index: %d)", userIndex), nil
return nil
} }
// GetUserIDClaim returns the JWT claim name used to extract user IDs // GetUserIDClaim returns the JWT claim name used to extract user IDs

View File

@@ -24,7 +24,7 @@ func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// Try to authorize a different user // Try to authorize a different user
err = authorizer.Authorize("unauthorized-user", "root") _, err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized) assert.ErrorIs(t, err, ErrUserNotAuthorized)
} }
@@ -45,15 +45,15 @@ func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T)
authorizer.Update(config) authorizer.Update(config)
// All attempts should fail when no machine user mappings exist (fail closed) // All attempts should fail when no machine user mappings exist (fail closed)
err = authorizer.Authorize("user1", "root") _, err = authorizer.Authorize("user1", "root")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping) assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin") _, err = authorizer.Authorize("user2", "admin")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping) assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user1", "postgres") _, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping) assert.ErrorIs(t, err, ErrNoMachineUserMapping)
} }
@@ -80,21 +80,21 @@ func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testi
authorizer.Update(config) authorizer.Update(config)
// user1 (index 0) should access root and admin // user1 (index 0) should access root and admin
err = authorizer.Authorize("user1", "root") _, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err) assert.NoError(t, err)
err = authorizer.Authorize("user1", "admin") _, err = authorizer.Authorize("user1", "admin")
assert.NoError(t, err) assert.NoError(t, err)
// user2 (index 1) should access root and postgres // user2 (index 1) should access root and postgres
err = authorizer.Authorize("user2", "root") _, err = authorizer.Authorize("user2", "root")
assert.NoError(t, err) assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres") _, err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err) assert.NoError(t, err)
// user3 (index 2) should access postgres // user3 (index 2) should access postgres
err = authorizer.Authorize("user3", "postgres") _, err = authorizer.Authorize("user3", "postgres")
assert.NoError(t, err) assert.NoError(t, err)
} }
@@ -121,22 +121,22 @@ func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testin
authorizer.Update(config) authorizer.Update(config)
// user1 (index 0) should NOT access postgres // user1 (index 0) should NOT access postgres
err = authorizer.Authorize("user1", "postgres") _, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 (index 1) should NOT access admin // user2 (index 1) should NOT access admin
err = authorizer.Authorize("user2", "admin") _, err = authorizer.Authorize("user2", "admin")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access root // user3 (index 2) should NOT access root
err = authorizer.Authorize("user3", "root") _, err = authorizer.Authorize("user3", "root")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access admin // user3 (index 2) should NOT access admin
err = authorizer.Authorize("user3", "admin") _, err = authorizer.Authorize("user3", "admin")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
} }
@@ -158,7 +158,7 @@ func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// user1 should NOT access an unmapped OS user (fail closed) // user1 should NOT access an unmapped OS user (fail closed)
err = authorizer.Authorize("user1", "postgres") _, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping) assert.ErrorIs(t, err, ErrNoMachineUserMapping)
} }
@@ -178,7 +178,7 @@ func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// Empty user ID should fail // Empty user ID should fail
err = authorizer.Authorize("", "root") _, err = authorizer.Authorize("", "root")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrEmptyUserID) assert.ErrorIs(t, err, ErrEmptyUserID)
} }
@@ -211,12 +211,12 @@ func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
// All users should be authorized for root // All users should be authorized for root
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
err := authorizer.Authorize("user"+string(rune('0'+i)), "root") _, err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
assert.NoError(t, err, "user%d should be authorized", i) assert.NoError(t, err, "user%d should be authorized", i)
} }
// User not in list should fail // User not in list should fail
err := authorizer.Authorize("unknown-user", "root") _, err := authorizer.Authorize("unknown-user", "root")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized) assert.ErrorIs(t, err, ErrUserNotAuthorized)
} }
@@ -236,14 +236,14 @@ func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// user1 should be authorized // user1 should be authorized
err = authorizer.Authorize("user1", "root") _, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err) assert.NoError(t, err)
// Clear configuration // Clear configuration
authorizer.Update(nil) authorizer.Update(nil)
// user1 should no longer be authorized // user1 should no longer be authorized
err = authorizer.Authorize("user1", "root") _, err = authorizer.Authorize("user1", "root")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized) assert.ErrorIs(t, err, ErrUserNotAuthorized)
} }
@@ -267,16 +267,16 @@ func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// root should work // root should work
err = authorizer.Authorize("user1", "root") _, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err) assert.NoError(t, err)
// postgres should fail (no mapping) // postgres should fail (no mapping)
err = authorizer.Authorize("user1", "postgres") _, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping) assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// admin should fail (no mapping) // admin should fail (no mapping)
err = authorizer.Authorize("user1", "admin") _, err = authorizer.Authorize("user1", "admin")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping) assert.ErrorIs(t, err, ErrNoMachineUserMapping)
} }
@@ -301,7 +301,7 @@ func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
assert.Equal(t, "email", authorizer.GetUserIDClaim()) assert.Equal(t, "email", authorizer.GetUserIDClaim())
// Authorize with email as user ID // Authorize with email as user ID
err = authorizer.Authorize("user@example.com", "root") _, err = authorizer.Authorize("user@example.com", "root")
assert.NoError(t, err) assert.NoError(t, err)
} }
@@ -349,19 +349,19 @@ func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// First user should have access // First user should have access
err := authorizer.Authorize("user"+string(rune(0)), "root") _, err := authorizer.Authorize("user"+string(rune(0)), "root")
assert.NoError(t, err) assert.NoError(t, err)
// Middle user should have access // Middle user should have access
err = authorizer.Authorize("user"+string(rune(500)), "root") _, err = authorizer.Authorize("user"+string(rune(500)), "root")
assert.NoError(t, err) assert.NoError(t, err)
// Last user should have access // Last user should have access
err = authorizer.Authorize("user"+string(rune(999)), "root") _, err = authorizer.Authorize("user"+string(rune(999)), "root")
assert.NoError(t, err) assert.NoError(t, err)
// User not in mapping should NOT have access // User not in mapping should NOT have access
err = authorizer.Authorize("user"+string(rune(100)), "root") _, err = authorizer.Authorize("user"+string(rune(100)), "root")
assert.Error(t, err) assert.Error(t, err)
} }
@@ -393,7 +393,7 @@ func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
if idx%2 == 0 { if idx%2 == 0 {
user = "user2" user = "user2"
} }
err := authorizer.Authorize(user, "root") _, err := authorizer.Authorize(user, "root")
errChan <- err errChan <- err
}(i) }(i)
} }
@@ -426,22 +426,22 @@ func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// All authorized users should be able to access any OS user // All authorized users should be able to access any OS user
err = authorizer.Authorize("user1", "root") _, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err) assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres") _, err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err) assert.NoError(t, err)
err = authorizer.Authorize("user3", "admin") _, err = authorizer.Authorize("user3", "admin")
assert.NoError(t, err) assert.NoError(t, err)
err = authorizer.Authorize("user1", "ubuntu") _, err = authorizer.Authorize("user1", "ubuntu")
assert.NoError(t, err) assert.NoError(t, err)
err = authorizer.Authorize("user2", "nginx") _, err = authorizer.Authorize("user2", "nginx")
assert.NoError(t, err) assert.NoError(t, err)
err = authorizer.Authorize("user3", "docker") _, err = authorizer.Authorize("user3", "docker")
assert.NoError(t, err) assert.NoError(t, err)
} }
@@ -462,11 +462,11 @@ func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// user1 should have access // user1 should have access
err = authorizer.Authorize("user1", "root") _, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err) assert.NoError(t, err)
// Unauthorized user should still be denied even with wildcard // Unauthorized user should still be denied even with wildcard
err = authorizer.Authorize("unauthorized-user", "root") _, err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized) assert.ErrorIs(t, err, ErrUserNotAuthorized)
} }
@@ -492,17 +492,17 @@ func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// Both users should be able to access root via wildcard (takes precedence over specific mapping) // Both users should be able to access root via wildcard (takes precedence over specific mapping)
err = authorizer.Authorize("user1", "root") _, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err) assert.NoError(t, err)
err = authorizer.Authorize("user2", "root") _, err = authorizer.Authorize("user2", "root")
assert.NoError(t, err) assert.NoError(t, err)
// Both users should be able to access any other OS user via wildcard // Both users should be able to access any other OS user via wildcard
err = authorizer.Authorize("user1", "postgres") _, err = authorizer.Authorize("user1", "postgres")
assert.NoError(t, err) assert.NoError(t, err)
err = authorizer.Authorize("user2", "admin") _, err = authorizer.Authorize("user2", "admin")
assert.NoError(t, err) assert.NoError(t, err)
} }
@@ -526,29 +526,29 @@ func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// user1 can access root // user1 can access root
err = authorizer.Authorize("user1", "root") _, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err) assert.NoError(t, err)
// user2 can access postgres // user2 can access postgres
err = authorizer.Authorize("user2", "postgres") _, err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err) assert.NoError(t, err)
// user1 cannot access postgres // user1 cannot access postgres
err = authorizer.Authorize("user1", "postgres") _, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 cannot access root // user2 cannot access root
err = authorizer.Authorize("user2", "root") _, err = authorizer.Authorize("user2", "root")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// Neither can access unmapped OS users // Neither can access unmapped OS users
err = authorizer.Authorize("user1", "admin") _, err = authorizer.Authorize("user1", "admin")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping) assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin") _, err = authorizer.Authorize("user2", "admin")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping) assert.ErrorIs(t, err, ErrNoMachineUserMapping)
} }
@@ -578,35 +578,35 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
authorizer.Update(config) authorizer.Update(config)
// wasm (index 0) should access any OS user via wildcard // wasm (index 0) should access any OS user via wildcard
err = authorizer.Authorize("wasm", "root") _, err = authorizer.Authorize("wasm", "root")
assert.NoError(t, err, "wasm should access root via wildcard") assert.NoError(t, err, "wasm should access root via wildcard")
err = authorizer.Authorize("wasm", "alice") _, err = authorizer.Authorize("wasm", "alice")
assert.NoError(t, err, "wasm should access alice via wildcard") assert.NoError(t, err, "wasm should access alice via wildcard")
err = authorizer.Authorize("wasm", "bob") _, err = authorizer.Authorize("wasm", "bob")
assert.NoError(t, err, "wasm should access bob via wildcard") assert.NoError(t, err, "wasm should access bob via wildcard")
err = authorizer.Authorize("wasm", "postgres") _, err = authorizer.Authorize("wasm", "postgres")
assert.NoError(t, err, "wasm should access postgres via wildcard") assert.NoError(t, err, "wasm should access postgres via wildcard")
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres // user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
err = authorizer.Authorize("user2", "alice") _, err = authorizer.Authorize("user2", "alice")
assert.NoError(t, err, "user2 should access alice via explicit mapping") assert.NoError(t, err, "user2 should access alice via explicit mapping")
err = authorizer.Authorize("user2", "bob") _, err = authorizer.Authorize("user2", "bob")
assert.NoError(t, err, "user2 should access bob via explicit mapping") assert.NoError(t, err, "user2 should access bob via explicit mapping")
err = authorizer.Authorize("user2", "root") _, err = authorizer.Authorize("user2", "root")
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)") assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping) assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "postgres") _, err = authorizer.Authorize("user2", "postgres")
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)") assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping) assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// Unauthorized user should still be denied // Unauthorized user should still be denied
err = authorizer.Authorize("user3", "root") _, err = authorizer.Authorize("user3", "root")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied") assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
} }

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@@ -551,14 +550,15 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
defer func() { defer func() {
if err := localConn.Close(); err != nil { if err := localConn.Close(); err != nil {
log.Debugf("local connection close error: %v", err) log.Debugf("local port forwarding: close local connection: %v", err)
} }
}() }()
channel, err := c.client.Dial("tcp", remoteAddr) channel, err := c.client.Dial("tcp", remoteAddr)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "administratively prohibited") { var openErr *ssh.OpenChannelError
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n") if errors.As(err, &openErr) && openErr.Reason == ssh.Prohibited {
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: port forwarding is disabled\n")
} else { } else {
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err) log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
} }
@@ -566,19 +566,11 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
} }
defer func() { defer func() {
if err := channel.Close(); err != nil { if err := channel.Close(); err != nil {
log.Debugf("remote channel close error: %v", err) log.Debugf("local port forwarding: close remote channel: %v", err)
} }
}() }()
go func() { nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("local forward copy error (local->remote): %v", err)
}
}()
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("local forward copy error (remote->local): %v", err)
}
} }
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr // RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
@@ -633,7 +625,7 @@ func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
return fmt.Errorf("send tcpip-forward request: %w", err) return fmt.Errorf("send tcpip-forward request: %w", err)
} }
if !ok { if !ok {
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)") return fmt.Errorf("remote port forwarding denied by server")
} }
return nil return nil
} }
@@ -676,7 +668,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
} }
defer func() { defer func() {
if err := channel.Close(); err != nil { if err := channel.Close(); err != nil {
log.Debugf("remote channel close error: %v", err) log.Debugf("remote port forwarding: close remote channel: %v", err)
} }
}() }()
@@ -688,19 +680,11 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
} }
defer func() { defer func() {
if err := localConn.Close(); err != nil { if err := localConn.Close(); err != nil {
log.Debugf("local connection close error: %v", err) log.Debugf("remote port forwarding: close local connection: %v", err)
} }
}() }()
go func() { nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("remote forward copy error (remote->local): %v", err)
}
}()
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("remote forward copy error (local->remote): %v", err)
}
} }
// tcpipForwardMsg represents the structure for tcpip-forward requests // tcpipForwardMsg represents the structure for tcpip-forward requests

View File

@@ -193,3 +193,64 @@ func buildAddressList(hostname string, remote net.Addr) []string {
} }
return addresses return addresses
} }
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
// It waits for both directions to complete before returning.
// The caller is responsible for closing the connections.
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
<-done
<-done
}
func isExpectedCopyError(err error) bool {
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
}
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
// It waits for both directions to complete or for context cancellation before returning.
// Both connections are closed when the function returns.
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
case <-done:
select {
case <-ctx.Done():
case <-done:
}
}
_ = conn1.Close()
_ = conn2.Close()
}

View File

@@ -2,6 +2,7 @@ package proxy
import ( import (
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -42,6 +43,14 @@ type SSHProxy struct {
conn *grpc.ClientConn conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient daemonClient proto.DaemonServiceClient
browserOpener func(string) error browserOpener func(string) error
mu sync.RWMutex
backendClient *cryptossh.Client
// jwtToken is set once in runProxySSHServer before any handlers are called,
// so concurrent access is safe without additional synchronization.
jwtToken string
forwardedChannelsOnce sync.Once
} }
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) { func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
@@ -63,6 +72,17 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browse
} }
func (p *SSHProxy) Close() error { func (p *SSHProxy) Close() error {
p.mu.Lock()
backendClient := p.backendClient
p.backendClient = nil
p.mu.Unlock()
if backendClient != nil {
if err := backendClient.Close(); err != nil {
log.Debugf("close backend client: %v", err)
}
}
if p.conn != nil { if p.conn != nil {
return p.conn.Close() return p.conn.Close()
} }
@@ -77,16 +97,16 @@ func (p *SSHProxy) Connect(ctx context.Context) error {
return fmt.Errorf(jwtAuthErrorMsg, err) return fmt.Errorf(jwtAuthErrorMsg, err)
} }
return p.runProxySSHServer(ctx, jwtToken) log.Debugf("JWT authentication successful, starting proxy to %s:%d", p.targetHost, p.targetPort)
return p.runProxySSHServer(jwtToken)
} }
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error { func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
p.jwtToken = jwtToken
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion()) serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
sshServer := &ssh.Server{ sshServer := &ssh.Server{
Handler: func(s ssh.Session) { Handler: p.handleSSHSession,
p.handleSSHSession(ctx, s, jwtToken)
},
ChannelHandlers: map[string]ssh.ChannelHandler{ ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler, "session": ssh.DefaultSessionHandler,
"direct-tcpip": p.directTCPIPHandler, "direct-tcpip": p.directTCPIPHandler,
@@ -119,15 +139,20 @@ func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error
return nil return nil
} }
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) { func (p *SSHProxy) handleSSHSession(session ssh.Session) {
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken) sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil { if err != nil {
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err) _, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
return return
} }
defer func() { _ = sshClient.Close() }()
if !isPty && !hasCommand {
p.handleNonInteractiveSession(session, sshClient)
return
}
serverSession, err := sshClient.NewSession() serverSession, err := sshClient.NewSession()
if err != nil { if err != nil {
@@ -140,7 +165,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
serverSession.Stdout = session serverSession.Stdout = session
serverSession.Stderr = session.Stderr() serverSession.Stderr = session.Stderr()
ptyReq, winCh, isPty := session.Pty()
if isPty { if isPty {
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil { if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
log.Debugf("PTY request to backend: %v", err) log.Debugf("PTY request to backend: %v", err)
@@ -155,7 +179,7 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
}() }()
} }
if len(session.Command()) > 0 { if hasCommand {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
log.Debugf("run command: %v", err) log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err) p.handleProxyExitCode(session, err)
@@ -176,12 +200,29 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) { func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
var exitErr *cryptossh.ExitError var exitErr *cryptossh.ExitError
if errors.As(err, &exitErr) { if errors.As(err, &exitErr) {
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil { if err := session.Exit(exitErr.ExitStatus()); err != nil {
log.Debugf("set exit status: %v", exitErr) log.Debugf("set exit status: %v", err)
} }
} }
} }
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
// Create a backend session to mirror the client's session request.
// This keeps the connection alive on the server side while port forwarding channels operate.
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
return
}
defer func() { _ = serverSession.Close() }()
<-session.Context().Done()
if err := session.Exit(0); err != nil {
log.Debugf("session exit: %v", err)
}
}
func generateHostKey() (ssh.Signer, error) { func generateHostKey() (ssh.Signer, error) {
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519) keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil { if err != nil {
@@ -250,8 +291,52 @@ func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
return nil return nil
} }
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) { // directTCPIPHandler handles local port forwarding (direct-tcpip channel).
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy") func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, sshCtx ssh.Context) {
var payload struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse direct-tcpip payload: %v\n", err)
_ = newChan.Reject(cryptossh.ConnectionFailed, "invalid payload")
return
}
dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort)
log.Debugf("local port forwarding: %s", dest)
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "backend connection for port forwarding: %v\n", err)
_ = newChan.Reject(cryptossh.ConnectionFailed, "backend connection failed")
return
}
backendChan, backendReqs, err := backendClient.OpenChannel("direct-tcpip", newChan.ExtraData())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "open backend channel for %s: %v\n", dest, err)
var openErr *cryptossh.OpenChannelError
if errors.As(err, &openErr) {
_ = newChan.Reject(openErr.Reason, openErr.Message)
} else {
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
}
return
}
go cryptossh.DiscardRequests(backendReqs)
clientChan, clientReqs, err := newChan.Accept()
if err != nil {
log.Debugf("local port forwarding: accept channel: %v", err)
_ = backendChan.Close()
return
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
} }
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) { func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
@@ -354,12 +439,143 @@ func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.Wr
} }
} }
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { // tcpipForwardHandler handles remote port forwarding (tcpip-forward request).
return false, []byte("port forwarding not supported in proxy") func (p *SSHProxy) tcpipForwardHandler(sshCtx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
var reqPayload struct {
Host string
Port uint32
}
if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse tcpip-forward payload: %v\n", err)
return false, nil
}
log.Debugf("tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "backend connection for remote port forwarding: %v\n", err)
return false, nil
}
ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "forward tcpip-forward request for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
return false, nil
}
if ok {
actualPort := reqPayload.Port
if reqPayload.Port == 0 && len(payload) >= 4 {
actualPort = binary.BigEndian.Uint32(payload)
}
log.Debugf("remote port forwarding established for %s:%d", reqPayload.Host, actualPort)
p.forwardedChannelsOnce.Do(func() {
go p.handleForwardedChannels(sshCtx, backendClient)
})
}
return ok, payload
} }
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { // cancelTcpipForwardHandler handles cancel-tcpip-forward request.
return true, nil func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
var reqPayload struct {
Host string
Port uint32
}
if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse cancel-tcpip-forward payload: %v\n", err)
return false, nil
}
log.Debugf("cancel-tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
backendClient := p.getBackendClient()
if backendClient == nil {
return false, nil
}
ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "cancel-tcpip-forward for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
return false, nil
}
return ok, payload
}
// getOrCreateBackendClient returns the existing backend client or creates a new one.
func (p *SSHProxy) getOrCreateBackendClient(ctx context.Context, user string) (*cryptossh.Client, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.backendClient != nil {
return p.backendClient, nil
}
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
log.Debugf("connecting to backend %s", targetAddr)
client, err := p.dialBackend(ctx, targetAddr, user, p.jwtToken)
if err != nil {
return nil, err
}
log.Debugf("backend connection established to %s", targetAddr)
p.backendClient = client
return client, nil
}
// getBackendClient returns the existing backend client or nil.
func (p *SSHProxy) getBackendClient() *cryptossh.Client {
p.mu.RLock()
defer p.mu.RUnlock()
return p.backendClient
}
// handleForwardedChannels handles forwarded-tcpip channels from the backend for remote port forwarding.
// When the backend receives incoming connections on the forwarded port, it sends them as
// "forwarded-tcpip" channels which we need to proxy to the client.
func (p *SSHProxy) handleForwardedChannels(sshCtx ssh.Context, backendClient *cryptossh.Client) {
sshConn, ok := sshCtx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if !ok || sshConn == nil {
log.Debugf("no SSH connection in context for forwarded channels")
return
}
channelChan := backendClient.HandleChannelOpen("forwarded-tcpip")
for {
select {
case <-sshCtx.Done():
return
case newChannel, ok := <-channelChan:
if !ok {
return
}
go p.handleForwardedChannel(sshCtx, sshConn, newChannel)
}
}
}
// handleForwardedChannel handles a single forwarded-tcpip channel from the backend.
func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh.ServerConn, newChannel cryptossh.NewChannel) {
backendChan, backendReqs, err := newChannel.Accept()
if err != nil {
log.Debugf("remote port forwarding: accept from backend: %v", err)
return
}
go cryptossh.DiscardRequests(backendReqs)
clientChan, clientReqs, err := sshConn.OpenChannel("forwarded-tcpip", newChannel.ExtraData())
if err != nil {
log.Debugf("remote port forwarding: open to client: %v", err)
_ = backendChan.Close()
return
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
} }
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) { func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {

View File

@@ -1,25 +1,32 @@
// Package server implements port forwarding for the SSH server.
//
// Security note: Port forwarding runs in the main server process without privilege separation.
// The attack surface is primarily io.Copy through well-tested standard library code, making it
// lower risk than shell execution which uses privilege-separated child processes. We enforce
// user-level port restrictions: non-privileged users cannot bind to ports < 1024.
package server package server
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"net" "net"
"runtime"
"strconv" "strconv"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh" cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
) )
// SessionKey uniquely identifies an SSH session const privilegedPortThreshold = 1024
type SessionKey string
// ConnectionKey uniquely identifies a port forwarding connection within a session // sessionKey uniquely identifies an SSH session
type ConnectionKey string type sessionKey string
// ForwardKey uniquely identifies a port forwarding listener // forwardKey uniquely identifies a port forwarding listener
type ForwardKey string type forwardKey string
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests // tcpipForwardMsg represents the structure for tcpip-forward SSH requests
type tcpipForwardMsg struct { type tcpipForwardMsg struct {
@@ -47,34 +54,32 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
allowRemote := s.allowRemotePortForwarding allowRemote := s.allowRemotePortForwarding
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool { server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowLocal { if !allowLocal {
log.Warnf("local port forwarding denied for %s from %s: disabled by configuration", logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr())
return false return false
} }
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil { if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err) logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
return false return false
} }
log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort)
return true return true
} }
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool { server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowRemote { if !allowRemote {
log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration", logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr())
return false return false
} }
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil { if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err) logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
return false return false
} }
log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort)
return true return true
} }
@@ -82,23 +87,20 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
} }
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations. // checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
// Returns nil if allowed, error if denied. // For remote port forwarding (binding), it enforces that non-privileged users cannot bind to
// ports below 1024, mirroring the restriction they would face if binding directly.
//
// Note: FeatureSupportsUserSwitch is true because we accept requests from any authenticated user,
// though we don't actually switch users - port forwarding runs in the server process. The resolved
// user is used for privileged port access checks.
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error { func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
if ctx == nil { if ctx == nil {
return fmt.Errorf("%s port forwarding denied: no context", forwardType) return fmt.Errorf("%s port forwarding denied: no context", forwardType)
} }
username := ctx.User()
remoteAddr := "unknown"
if ctx.RemoteAddr() != nil {
remoteAddr = ctx.RemoteAddr().String()
}
logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port})
result := s.CheckPrivileges(PrivilegeCheckRequest{ result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: username, RequestedUsername: ctx.User(),
FeatureSupportsUserSwitch: false, FeatureSupportsUserSwitch: true,
FeatureName: forwardType + " port forwarding", FeatureName: forwardType + " port forwarding",
}) })
@@ -106,12 +108,42 @@ func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType stri
return result.Error return result.Error
} }
logger.Debugf("%s port forwarding allowed: user %s validated (port %d)", if err := s.checkPrivilegedPortAccess(forwardType, port, result); err != nil {
forwardType, result.User.Username, port) return err
}
return nil return nil
} }
// checkPrivilegedPortAccess enforces that non-privileged users cannot bind to privileged ports.
// This applies to remote port forwarding where the server binds a port on behalf of the user.
// On Windows, there is no privileged port restriction, so this check is skipped.
func (s *Server) checkPrivilegedPortAccess(forwardType string, port uint32, result PrivilegeCheckResult) error {
if runtime.GOOS == "windows" {
return nil
}
isBindOperation := forwardType == "remote" || forwardType == "tcpip-forward"
if !isBindOperation {
return nil
}
// Port 0 means "pick any available port", which will be >= 1024
if port == 0 || port >= privilegedPortThreshold {
return nil
}
if result.User != nil && isPrivilegedUsername(result.User.Username) {
return nil
}
username := "unknown"
if result.User != nil {
username = result.User.Username
}
return fmt.Errorf("user %s cannot bind to privileged port %d (requires root)", username, port)
}
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding. // tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
logger := s.getRequestLogger(ctx) logger := s.getRequestLogger(ctx)
@@ -132,8 +164,6 @@ func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *crypto
return false, nil return false, nil
} }
logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
sshConn, err := s.getSSHConnection(ctx) sshConn, err := s.getSSHConnection(ctx)
if err != nil { if err != nil {
logger.Warnf("tcpip-forward request denied: %v", err) logger.Warnf("tcpip-forward request denied: %v", err)
@@ -153,8 +183,10 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
return false, nil return false, nil
} }
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
if s.removeRemoteForwardListener(key) { if s.removeRemoteForwardListener(key) {
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port) logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
return true, nil return true, nil
} }
@@ -165,14 +197,11 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
// handleRemoteForwardListener handles incoming connections for remote port forwarding. // handleRemoteForwardListener handles incoming connections for remote port forwarding.
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) { func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
log.Debugf("starting remote forward listener handler for %s:%d", host, port) logger := s.getRequestLogger(ctx)
defer func() { defer func() {
log.Debugf("cleaning up remote forward listener for %s:%d", host, port)
if err := ln.Close(); err != nil { if err := ln.Close(); err != nil {
log.Debugf("remote forward listener close error: %v", err) logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
} else {
log.Debugf("remote forward listener closed successfully for %s:%d", host, port)
} }
}() }()
@@ -196,28 +225,43 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
select { select {
case result := <-acceptChan: case result := <-acceptChan:
if result.err != nil { if result.err != nil {
log.Debugf("remote forward accept error: %v", result.err) logger.Debugf("remote forward accept error: %v", result.err)
return return
} }
go s.handleRemoteForwardConnection(ctx, result.conn, host, port) go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
case <-ctx.Done(): case <-ctx.Done():
log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port) logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
return return
} }
} }
} }
// getRequestLogger creates a logger with user and remote address context // getRequestLogger creates a logger with session/conn and jwt_user context
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry { func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
remoteAddr := "unknown" sessionKey := s.findSessionKeyByContext(ctx)
username := "unknown"
if ctx != nil { s.mu.RLock()
if ctx.RemoteAddr() != nil { defer s.mu.RUnlock()
remoteAddr = ctx.RemoteAddr().String()
if state, exists := s.sessions[sessionKey]; exists {
logger := log.WithField("session", sessionKey)
if state.jwtUsername != "" {
logger = logger.WithField("jwt_user", state.jwtUsername)
} }
username = ctx.User() return logger
} }
return log.WithFields(log.Fields{"user": username, "remote": remoteAddr})
if ctx.RemoteAddr() != nil {
if connState, exists := s.connections[connKey(ctx.RemoteAddr().String())]; exists {
return s.connLogger(connState)
}
}
remoteAddr := "unknown"
if ctx.RemoteAddr() != nil {
remoteAddr = ctx.RemoteAddr().String()
}
return log.WithField("session", fmt.Sprintf("%s@%s", ctx.User(), remoteAddr))
} }
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled // isRemotePortForwardingAllowed checks if remote port forwarding is enabled
@@ -227,6 +271,13 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
return s.allowRemotePortForwarding return s.allowRemotePortForwarding
} }
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
func (s *Server) isPortForwardingEnabled() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
}
// parseTcpipForwardRequest parses the SSH request payload // parseTcpipForwardRequest parses the SSH request payload
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) { func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
var payload tcpipForwardMsg var payload tcpipForwardMsg
@@ -267,10 +318,11 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host) logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
} }
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
s.storeRemoteForwardListener(key, ln) s.storeRemoteForwardListener(key, ln)
s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String()) forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort) go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
response := make([]byte, 4) response := make([]byte, 4)
@@ -288,44 +340,34 @@ type acceptResult struct {
// handleRemoteForwardConnection handles a single remote port forwarding connection // handleRemoteForwardConnection handles a single remote port forwarding connection
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) { func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
sessionKey := s.findSessionKeyByContext(ctx) logger := s.getRequestLogger(ctx)
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
logger := log.WithFields(log.Fields{
"session": sessionKey,
"conn": connID,
})
defer func() { sshConn, ok := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if err := conn.Close(); err != nil { if !ok || sshConn == nil {
logger.Debugf("connection close error: %v", err)
}
}()
sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if sshConn == nil {
logger.Debugf("remote forward: no SSH connection in context") logger.Debugf("remote forward: no SSH connection in context")
_ = conn.Close()
return return
} }
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr) remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok { if !ok {
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr()) logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
_ = conn.Close()
return return
} }
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger) channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
if err != nil { if err != nil {
logger.Debugf("open forward channel: %v", err) logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
_ = conn.Close()
return return
} }
s.proxyForwardConnection(ctx, logger, conn, channel) nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
} }
// openForwardChannel creates an SSH forwarded-tcpip channel // openForwardChannel creates an SSH forwarded-tcpip channel
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) { func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr) (cryptossh.Channel, error) {
logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port)
payload := struct { payload := struct {
ConnectedAddress string ConnectedAddress string
ConnectedPort uint32 ConnectedPort uint32
@@ -346,41 +388,3 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string,
go cryptossh.DiscardRequests(reqs) go cryptossh.DiscardRequests(reqs)
return channel, nil return channel, nil
} }
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(channel, conn); err != nil {
logger.Debugf("copy error (conn->channel): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn, channel); err != nil {
logger.Debugf("copy error (channel->conn): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
logger.Debugf("session ended, closing connections")
case <-done:
// First copy finished, wait for second copy or context cancellation
select {
case <-ctx.Done():
logger.Debugf("session ended, closing connections")
case <-done:
}
}
if err := channel.Close(); err != nil {
logger.Debugf("channel close error: %v", err)
}
if err := conn.Close(); err != nil {
logger.Debugf("connection close error: %v", err)
}
}

View File

@@ -9,6 +9,7 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"slices"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -40,6 +41,11 @@ const (
msgPrivilegedUserDisabled = "privileged user login is disabled" msgPrivilegedUserDisabled = "privileged user login is disabled"
cmdInteractiveShell = "<interactive shell>"
cmdPortForwarding = "<port forwarding>"
cmdSFTP = "<sftp>"
cmdNonInteractive = "<idle>"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60 DefaultJWTMaxTokenAge = 5 * 60
) )
@@ -90,10 +96,10 @@ func logSessionExitError(logger *log.Entry, err error) {
} }
} }
// safeLogCommand returns a safe representation of the command for logging // safeLogCommand returns a safe representation of the command for logging.
func safeLogCommand(cmd []string) string { func safeLogCommand(cmd []string) string {
if len(cmd) == 0 { if len(cmd) == 0 {
return "<interactive shell>" return cmdInteractiveShell
} }
if len(cmd) == 1 { if len(cmd) == 1 {
return cmd[0] return cmd[0]
@@ -101,26 +107,50 @@ func safeLogCommand(cmd []string) string {
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1) return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
} }
type sshConnectionState struct { // connState tracks the state of an SSH connection for port forwarding and status display.
hasActivePortForward bool type connState struct {
username string username string
remoteAddr string remoteAddr net.Addr
portForwards []string
jwtUsername string
} }
// authKey uniquely identifies an authentication attempt by username and remote address.
// Used to temporarily store JWT username between passwordHandler and sessionHandler.
type authKey string type authKey string
// connKey uniquely identifies an SSH connection by its remote address.
// Used to track authenticated connections for status display and port forwarding.
type connKey string
func newAuthKey(username string, remoteAddr net.Addr) authKey { func newAuthKey(username string, remoteAddr net.Addr) authKey {
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String())) return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
} }
// sessionState tracks an active SSH session (shell, command, or subsystem like SFTP).
type sessionState struct {
session ssh.Session
sessionType string
jwtUsername string
}
type Server struct { type Server struct {
sshServer *ssh.Server sshServer *ssh.Server
mu sync.RWMutex mu sync.RWMutex
hostKeyPEM []byte hostKeyPEM []byte
sessions map[SessionKey]ssh.Session
sessionCancels map[ConnectionKey]context.CancelFunc // sessions tracks active SSH sessions (shell, command, SFTP).
sessionJWTUsers map[SessionKey]string // These are created when a client opens a session channel and requests shell/exec/subsystem.
pendingAuthJWT map[authKey]string sessions map[sessionKey]*sessionState
// pendingAuthJWT temporarily stores JWT username during the auth→session handoff.
// Populated in passwordHandler, consumed in sessionHandler/sftpSubsystemHandler.
pendingAuthJWT map[authKey]string
// connections tracks all SSH connections by their remote address.
// Populated at authentication time, stores JWT username and port forwards for status display.
connections map[connKey]*connState
allowLocalPortForwarding bool allowLocalPortForwarding bool
allowRemotePortForwarding bool allowRemotePortForwarding bool
@@ -132,8 +162,7 @@ type Server struct {
wgAddress wgaddr.Address wgAddress wgaddr.Address
remoteForwardListeners map[ForwardKey]net.Listener remoteForwardListeners map[forwardKey]net.Listener
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
jwtValidator *jwt.Validator jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor jwtExtractor *jwt.ClaimsExtractor
@@ -167,6 +196,7 @@ type SessionInfo struct {
RemoteAddress string RemoteAddress string
Command string Command string
JWTUsername string JWTUsername string
PortForwards []string
} }
// New creates an SSH server instance with the provided host key and optional JWT configuration // New creates an SSH server instance with the provided host key and optional JWT configuration
@@ -175,11 +205,10 @@ func New(config *Config) *Server {
s := &Server{ s := &Server{
mu: sync.RWMutex{}, mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM, hostKeyPEM: config.HostKeyPEM,
sessions: make(map[SessionKey]ssh.Session), sessions: make(map[sessionKey]*sessionState),
sessionJWTUsers: make(map[SessionKey]string),
pendingAuthJWT: make(map[authKey]string), pendingAuthJWT: make(map[authKey]string),
remoteForwardListeners: make(map[ForwardKey]net.Listener), remoteForwardListeners: make(map[forwardKey]net.Listener),
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState), connections: make(map[connKey]*connState),
jwtEnabled: config.JWT != nil, jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT, jwtConfig: config.JWT,
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
@@ -265,14 +294,8 @@ func (s *Server) Stop() error {
s.sshServer = nil s.sshServer = nil
maps.Clear(s.sessions) maps.Clear(s.sessions)
maps.Clear(s.sessionJWTUsers)
maps.Clear(s.pendingAuthJWT) maps.Clear(s.pendingAuthJWT)
maps.Clear(s.sshConnections) maps.Clear(s.connections)
for _, cancelFunc := range s.sessionCancels {
cancelFunc()
}
maps.Clear(s.sessionCancels)
for _, listener := range s.remoteForwardListeners { for _, listener := range s.remoteForwardListeners {
if err := listener.Close(); err != nil { if err := listener.Close(); err != nil {
@@ -284,32 +307,70 @@ func (s *Server) Stop() error {
return nil return nil
} }
// GetStatus returns the current status of the SSH server and active sessions // GetStatus returns the current status of the SSH server and active sessions.
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) { func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
enabled = s.sshServer != nil enabled = s.sshServer != nil
reportedAddrs := make(map[string]bool)
for sessionKey, session := range s.sessions { for _, state := range s.sessions {
cmd := "<interactive shell>" info := s.buildSessionInfo(state)
if len(session.Command()) > 0 { reportedAddrs[info.RemoteAddress] = true
cmd = safeLogCommand(session.Command()) sessions = append(sessions, info)
}
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
for key, connState := range s.connections {
remoteAddr := string(key)
if reportedAddrs[remoteAddr] {
continue
}
cmd := cmdNonInteractive
if len(connState.portForwards) > 0 {
cmd = cmdPortForwarding
} }
jwtUsername := s.sessionJWTUsers[sessionKey]
sessions = append(sessions, SessionInfo{ sessions = append(sessions, SessionInfo{
Username: session.User(), Username: connState.username,
RemoteAddress: session.RemoteAddr().String(), RemoteAddress: remoteAddr,
Command: cmd, Command: cmd,
JWTUsername: jwtUsername, JWTUsername: connState.jwtUsername,
PortForwards: connState.portForwards,
}) })
} }
return enabled, sessions return enabled, sessions
} }
func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
session := state.session
cmd := state.sessionType
if cmd == "" {
cmd = safeLogCommand(session.Command())
}
remoteAddr := session.RemoteAddr().String()
info := SessionInfo{
Username: session.User(),
RemoteAddress: remoteAddr,
Command: cmd,
JWTUsername: state.jwtUsername,
}
connState, exists := s.connections[connKey(remoteAddr)]
if !exists {
return info
}
info.PortForwards = connState.portForwards
if len(connState.portForwards) > 0 && (cmd == cmdInteractiveShell || cmd == cmdNonInteractive) {
info.Command = cmdPortForwarding
}
return info
}
// SetNetstackNet sets the netstack network for userspace networking // SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) { func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock() s.mu.Lock()
@@ -520,69 +581,129 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool { func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
osUsername := ctx.User() osUsername := ctx.User()
remoteAddr := ctx.RemoteAddr() remoteAddr := ctx.RemoteAddr()
logger := s.getRequestLogger(ctx)
if err := s.ensureJWTValidator(); err != nil { if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err) logger.Errorf("JWT validator initialization failed: %v", err)
return false return false
} }
token, err := s.validateJWTToken(password) token, err := s.validateJWTToken(password)
if err != nil { if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err) logger.Warnf("JWT authentication failed: %v", err)
return false return false
} }
userAuth, err := s.extractAndValidateUser(token) userAuth, err := s.extractAndValidateUser(token)
if err != nil { if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err) logger.Warnf("user validation failed: %v", err)
return false return false
} }
logger = logger.WithField("jwt_user", userAuth.UserId)
s.mu.RLock() s.mu.RLock()
authorizer := s.authorizer authorizer := s.authorizer
s.mu.RUnlock() s.mu.RUnlock()
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil { msg, err := authorizer.Authorize(userAuth.UserId, osUsername)
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err) if err != nil {
logger.Warnf("SSH auth denied: %v", err)
return false return false
} }
logger.Infof("SSH auth %s", msg)
key := newAuthKey(osUsername, remoteAddr) key := newAuthKey(osUsername, remoteAddr)
remoteAddrStr := ctx.RemoteAddr().String()
s.mu.Lock() s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId s.pendingAuthJWT[key] = userAuth.UserId
s.connections[connKey(remoteAddrStr)] = &connState{
username: ctx.User(),
remoteAddr: ctx.RemoteAddr(),
jwtUsername: userAuth.UserId,
}
s.mu.Unlock() s.mu.Unlock()
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
return true return true
} }
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) { func (s *Server) addConnectionPortForward(username string, remoteAddr net.Addr, forwardAddr string) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if state, exists := s.sshConnections[sshConn]; exists { key := connKey(remoteAddr.String())
state.hasActivePortForward = true if state, exists := s.connections[key]; exists {
} else { if !slices.Contains(state.portForwards, forwardAddr) {
s.sshConnections[sshConn] = &sshConnectionState{ state.portForwards = append(state.portForwards, forwardAddr)
hasActivePortForward: true,
username: username,
remoteAddr: remoteAddr,
} }
return
}
// Connection not in connections (non-JWT auth path)
s.connections[key] = &connState{
username: username,
remoteAddr: remoteAddr,
portForwards: []string{forwardAddr},
jwtUsername: s.pendingAuthJWT[newAuthKey(username, remoteAddr)],
} }
} }
func (s *Server) connectionCloseHandler(conn net.Conn, err error) { func (s *Server) removeConnectionPortForward(remoteAddr net.Addr, forwardAddr string) {
// We can't extract the SSH connection from net.Conn directly s.mu.Lock()
// Connection cleanup will happen during session cleanup or via timeout defer s.mu.Unlock()
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
state, exists := s.connections[connKey(remoteAddr.String())]
if !exists {
return
}
state.portForwards = slices.DeleteFunc(state.portForwards, func(addr string) bool {
return addr == forwardAddr
})
} }
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey { // trackedConn wraps a net.Conn to detect when it closes
type trackedConn struct {
net.Conn
server *Server
remoteAddr string
onceClose sync.Once
}
func (c *trackedConn) Close() error {
err := c.Conn.Close()
c.onceClose.Do(func() {
c.server.handleConnectionClose(c.remoteAddr)
})
return err
}
func (s *Server) handleConnectionClose(remoteAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
key := connKey(remoteAddr)
state, exists := s.connections[key]
if exists && len(state.portForwards) > 0 {
s.connLogger(state).Info("port forwarding connection closed")
}
delete(s.connections, key)
}
func (s *Server) connLogger(state *connState) *log.Entry {
logger := log.WithField("session", fmt.Sprintf("%s@%s", state.username, state.remoteAddr))
if state.jwtUsername != "" {
logger = logger.WithField("jwt_user", state.jwtUsername)
}
return logger
}
func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey {
if ctx == nil { if ctx == nil {
return "unknown" return "unknown"
} }
// Try to match by SSH connection
sshConn := ctx.Value(ssh.ContextKeyConn) sshConn := ctx.Value(ssh.ContextKeyConn)
if sshConn == nil { if sshConn == nil {
return "unknown" return "unknown"
@@ -591,19 +712,14 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Look through sessions to find one with matching connection for sessionKey, state := range s.sessions {
for sessionKey, session := range s.sessions { if state.session.Context().Value(ssh.ContextKeyConn) == sshConn {
if session.Context().Value(ssh.ContextKeyConn) == sshConn {
return sessionKey return sessionKey
} }
} }
// If no session found, this might be during early connection setup
// Return a temporary key that we'll fix up later
if ctx.User() != "" && ctx.RemoteAddr() != nil { if ctx.User() != "" && ctx.RemoteAddr() != nil {
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String())) return sessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
return tempKey
} }
return "unknown" return "unknown"
@@ -644,7 +760,11 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
} }
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr) log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
return conn return &trackedConn{
Conn: conn,
server: s,
remoteAddr: conn.RemoteAddr().String(),
}
} }
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
@@ -672,9 +792,8 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
"tcpip-forward": s.tcpipForwardHandler, "tcpip-forward": s.tcpipForwardHandler,
"cancel-tcpip-forward": s.cancelTcpipForwardHandler, "cancel-tcpip-forward": s.cancelTcpipForwardHandler,
}, },
ConnCallback: s.connectionValidator, ConnCallback: s.connectionValidator,
ConnectionFailedCallback: s.connectionCloseHandler, Version: serverVersion,
Version: serverVersion,
} }
if s.jwtEnabled { if s.jwtEnabled {
@@ -690,13 +809,13 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
return server, nil return server, nil
} }
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) { func (s *Server) storeRemoteForwardListener(key forwardKey, ln net.Listener) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.remoteForwardListeners[key] = ln s.remoteForwardListeners[key] = ln
} }
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool { func (s *Server) removeRemoteForwardListener(key forwardKey) bool {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -714,6 +833,8 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
} }
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) { func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
logger := s.getRequestLogger(ctx)
var payload struct { var payload struct {
Host string Host string
Port uint32 Port uint32
@@ -723,7 +844,7 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil { if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil { if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
log.Debugf("channel reject error: %v", err) logger.Debugf("channel reject error: %v", err)
} }
return return
} }
@@ -733,19 +854,20 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.mu.RUnlock() s.mu.RUnlock()
if !allowLocal { if !allowLocal {
log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port) logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled") _ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
return return
} }
// Check privilege requirements for the destination port
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil { if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges") _ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
return return
} }
log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
} }

View File

@@ -224,6 +224,96 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
} }
} }
func TestServer_PrivilegedPortAccess(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
}
server := New(serverConfig)
server.SetAllowRemotePortForwarding(true)
tests := []struct {
name string
forwardType string
port uint32
username string
expectError bool
errorMsg string
skipOnWindows bool
}{
{
name: "non-root user remote forward privileged port",
forwardType: "remote",
port: 80,
username: "testuser",
expectError: true,
errorMsg: "cannot bind to privileged port",
skipOnWindows: true,
},
{
name: "non-root user tcpip-forward privileged port",
forwardType: "tcpip-forward",
port: 443,
username: "testuser",
expectError: true,
errorMsg: "cannot bind to privileged port",
skipOnWindows: true,
},
{
name: "non-root user remote forward unprivileged port",
forwardType: "remote",
port: 8080,
username: "testuser",
expectError: false,
},
{
name: "non-root user remote forward port 0",
forwardType: "remote",
port: 0,
username: "testuser",
expectError: false,
},
{
name: "root user remote forward privileged port",
forwardType: "remote",
port: 22,
username: "root",
expectError: false,
},
{
name: "local forward privileged port allowed for non-root",
forwardType: "local",
port: 80,
username: "testuser",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skipOnWindows && runtime.GOOS == "windows" {
t.Skip("Windows does not have privileged port restrictions")
}
result := PrivilegeCheckResult{
Allowed: true,
User: &user.User{Username: tt.username},
}
err := server.checkPrivilegedPortAccess(tt.forwardType, tt.port, result)
if tt.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
require.NoError(t, err)
}
})
}
}
func TestServer_PortConflictHandling(t *testing.T) { func TestServer_PortConflictHandling(t *testing.T) {
// Test that multiple sessions requesting the same local port are handled naturally by the OS // Test that multiple sessions requesting the same local port are handled naturally by the OS
// Get current user for SSH connection // Get current user for SSH connection
@@ -392,3 +482,95 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
}) })
} }
} }
func TestServer_PortForwardingOnlySession(t *testing.T) {
// Test that sessions without PTY and command are allowed when port forwarding is enabled
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
tests := []struct {
name string
allowLocalForwarding bool
allowRemoteForwarding bool
expectAllowed bool
description string
}{
{
name: "session_allowed_with_local_forwarding",
allowLocalForwarding: true,
allowRemoteForwarding: false,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
},
{
name: "session_allowed_with_remote_forwarding",
allowLocalForwarding: false,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
},
{
name: "session_allowed_with_both",
allowLocalForwarding: true,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
},
{
name: "session_denied_without_forwarding",
allowLocalForwarding: false,
allowRemoteForwarding: false,
expectAllowed: false,
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
serverAddr := StartTestServer(t, server)
defer func() {
_ = server.Stop()
}()
// Connect to the server without requesting PTY or command
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client, err := sshclient.Dial(ctx, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
_ = client.Close()
}()
// Execute a command without PTY - this simulates ssh -T with no command
// The server should either allow it (port forwarding enabled) or reject it
output, err := client.ExecuteCommand(ctx, "")
if tt.expectAllowed {
// When allowed, the session stays open until cancelled
// ExecuteCommand with empty command should return without error
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
assert.NotContains(t, output, "port forwarding is disabled",
"Output should not contain port forwarding disabled message")
} else if err != nil {
// When denied, we expect an error message about port forwarding being disabled
assert.Contains(t, err.Error(), "port forwarding is disabled",
"Should get port forwarding disabled message")
}
})
}
}

View File

@@ -6,37 +6,45 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"strings"
"time" "time"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
) )
// associateJWTUsername extracts pending JWT username for the session and associates it with the session state.
// Returns the JWT username (empty if none) for logging purposes.
func (s *Server) associateJWTUsername(sess ssh.Session, sessionKey sessionKey) string {
key := newAuthKey(sess.User(), sess.RemoteAddr())
s.mu.Lock()
defer s.mu.Unlock()
jwtUsername := s.pendingAuthJWT[key]
if jwtUsername == "" {
return ""
}
if state, exists := s.sessions[sessionKey]; exists {
state.jwtUsername = jwtUsername
}
delete(s.pendingAuthJWT, key)
return jwtUsername
}
// sessionHandler handles SSH sessions // sessionHandler handles SSH sessions
func (s *Server) sessionHandler(session ssh.Session) { func (s *Server) sessionHandler(session ssh.Session) {
sessionKey := s.registerSession(session) sessionKey := s.registerSession(session, "")
jwtUsername := s.associateJWTUsername(session, sessionKey)
key := newAuthKey(session.User(), session.RemoteAddr())
s.mu.Lock()
jwtUsername := s.pendingAuthJWT[key]
if jwtUsername != "" {
s.sessionJWTUsers[sessionKey] = jwtUsername
delete(s.pendingAuthJWT, key)
}
s.mu.Unlock()
logger := log.WithField("session", sessionKey) logger := log.WithField("session", sessionKey)
if jwtUsername != "" { if jwtUsername != "" {
logger = logger.WithField("jwt_user", jwtUsername) logger = logger.WithField("jwt_user", jwtUsername)
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
} else {
logger.Infof("SSH session started")
} }
logger.Info("SSH session started")
sessionStart := time.Now() sessionStart := time.Now()
defer s.unregisterSession(sessionKey, session) defer s.unregisterSession(sessionKey)
defer func() { defer func() {
duration := time.Since(sessionStart).Round(time.Millisecond) duration := time.Since(sessionStart).Round(time.Millisecond)
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
@@ -65,27 +73,52 @@ func (s *Server) sessionHandler(session ssh.Session) {
// ssh <host> <cmd> - non-Pty command execution // ssh <host> <cmd> - non-Pty command execution
s.handleCommand(logger, session, privilegeResult, nil) s.handleCommand(logger, session, privilegeResult, nil)
default: default:
s.rejectInvalidSession(logger, session) // ssh -T (or ssh -N) - no PTY, no command
s.handleNonInteractiveSession(logger, session)
} }
} }
func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) { // handleNonInteractiveSession handles sessions that have no PTY and no command.
if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil { // These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
logger.Debugf(errWriteSession, err) func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
s.updateSessionType(session, cmdNonInteractive)
if !s.isPortForwardingEnabled() {
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-interactive session: port forwarding disabled")
return
} }
if err := session.Exit(1); err != nil {
<-session.Context().Done()
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err) logSessionExitError(logger, err)
} }
logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
} }
func (s *Server) registerSession(session ssh.Session) SessionKey { func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
s.mu.Lock()
defer s.mu.Unlock()
for _, state := range s.sessions {
if state.session == session {
state.sessionType = sessionType
return
}
}
}
func (s *Server) registerSession(session ssh.Session, sessionType string) sessionKey {
sessionID := session.Context().Value(ssh.ContextKeySessionID) sessionID := session.Context().Value(ssh.ContextKeySessionID)
if sessionID == nil { if sessionID == nil {
sessionID = fmt.Sprintf("%p", session) sessionID = fmt.Sprintf("%p", session)
} }
// Create a short 4-byte identifier from the full session ID
hasher := sha256.New() hasher := sha256.New()
hasher.Write([]byte(fmt.Sprintf("%v", sessionID))) hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
hash := hasher.Sum(nil) hash := hasher.Sum(nil)
@@ -93,43 +126,23 @@ func (s *Server) registerSession(session ssh.Session) SessionKey {
remoteAddr := session.RemoteAddr().String() remoteAddr := session.RemoteAddr().String()
username := session.User() username := session.User()
sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID)) sessionKey := sessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
s.mu.Lock() s.mu.Lock()
s.sessions[sessionKey] = session s.sessions[sessionKey] = &sessionState{
session: session,
sessionType: sessionType,
}
s.mu.Unlock() s.mu.Unlock()
return sessionKey return sessionKey
} }
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) { func (s *Server) unregisterSession(sessionKey sessionKey) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionKey) delete(s.sessions, sessionKey)
delete(s.sessionJWTUsers, sessionKey)
// Cancel all port forwarding connections for this session
var connectionsToCancel []ConnectionKey
for key := range s.sessionCancels {
if strings.HasPrefix(string(key), string(sessionKey)+"-") {
connectionsToCancel = append(connectionsToCancel, key)
}
}
for _, key := range connectionsToCancel {
if cancelFunc, exists := s.sessionCancels[key]; exists {
log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
cancelFunc()
delete(s.sessionCancels, key)
}
}
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
delete(s.sshConnections, sshConn)
}
}
s.mu.Unlock()
} }
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) { func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {

View File

@@ -18,14 +18,26 @@ func (s *Server) SetAllowSFTP(allow bool) {
// sftpSubsystemHandler handles SFTP subsystem requests // sftpSubsystemHandler handles SFTP subsystem requests
func (s *Server) sftpSubsystemHandler(sess ssh.Session) { func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
sessionKey := s.registerSession(sess, cmdSFTP)
defer s.unregisterSession(sessionKey)
jwtUsername := s.associateJWTUsername(sess, sessionKey)
logger := log.WithField("session", sessionKey)
if jwtUsername != "" {
logger = logger.WithField("jwt_user", jwtUsername)
}
logger.Info("SFTP session started")
defer logger.Info("SFTP session closed")
s.mu.RLock() s.mu.RLock()
allowSFTP := s.allowSFTP allowSFTP := s.allowSFTP
s.mu.RUnlock() s.mu.RUnlock()
if !allowSFTP { if !allowSFTP {
log.Debugf("SFTP subsystem request denied: SFTP disabled") logger.Debug("SFTP subsystem request denied: SFTP disabled")
if err := sess.Exit(1); err != nil { if err := sess.Exit(1); err != nil {
log.Debugf("SFTP session exit failed: %v", err) logger.Debugf("SFTP session exit: %v", err)
} }
return return
} }
@@ -37,31 +49,27 @@ func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
}) })
if !result.Allowed { if !result.Allowed {
log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error) logger.Warnf("SFTP access denied: %v", result.Error)
if err := sess.Exit(1); err != nil { if err := sess.Exit(1); err != nil {
log.Debugf("exit SFTP session: %v", err) logger.Debugf("exit SFTP session: %v", err)
} }
return return
} }
log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username)
if !result.RequiresUserSwitching { if !result.RequiresUserSwitching {
if err := s.executeSftpDirect(sess); err != nil { if err := s.executeSftpDirect(sess); err != nil {
log.Errorf("SFTP direct execution: %v", err) logger.Errorf("SFTP direct execution: %v", err)
} }
return return
} }
if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil { if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
log.Errorf("SFTP privilege drop execution: %v", err) logger.Errorf("SFTP privilege drop execution: %v", err)
} }
} }
// executeSftpDirect executes SFTP directly without privilege dropping // executeSftpDirect executes SFTP directly without privilege dropping
func (s *Server) executeSftpDirect(sess ssh.Session) error { func (s *Server) executeSftpDirect(sess ssh.Session) error {
log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User())
sftpServer, err := sftp.NewServer(sess) sftpServer, err := sftp.NewServer(sess)
if err != nil { if err != nil {
return fmt.Errorf("SFTP server creation: %w", err) return fmt.Errorf("SFTP server creation: %w", err)

View File

@@ -82,10 +82,11 @@ type NsServerGroupStateOutput struct {
} }
type SSHSessionOutput struct { type SSHSessionOutput struct {
Username string `json:"username" yaml:"username"` Username string `json:"username" yaml:"username"`
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"` RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
Command string `json:"command" yaml:"command"` Command string `json:"command" yaml:"command"`
JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"` JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
PortForwards []string `json:"portForwards,omitempty" yaml:"portForwards,omitempty"`
} }
type SSHServerStateOutput struct { type SSHServerStateOutput struct {
@@ -220,6 +221,7 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
RemoteAddress: session.GetRemoteAddress(), RemoteAddress: session.GetRemoteAddress(),
Command: session.GetCommand(), Command: session.GetCommand(),
JWTUsername: session.GetJwtUsername(), JWTUsername: session.GetJwtUsername(),
PortForwards: session.GetPortForwards(),
}) })
} }
@@ -475,6 +477,9 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
) )
} }
sshServerStatus += "\n " + sessionDisplay sshServerStatus += "\n " + sessionDisplay
for _, pf := range session.PortForwards {
sshServerStatus += "\n " + pf
}
} }
} }
} }