mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Add port forwarding to ssh proxy (#5031)
* Implement port forwarding for the ssh proxy * Allow user switching for port forwarding
This commit is contained in:
@@ -634,7 +634,11 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
if err := validateDestinationPort(remoteAddr); err != nil {
|
||||||
|
return fmt.Errorf("invalid remote address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
@@ -652,7 +656,11 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
if err := validateDestinationPort(localAddr); err != nil {
|
||||||
|
return fmt.Errorf("invalid local address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
@@ -663,6 +671,35 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateDestinationPort checks that the destination address has a valid port.
|
||||||
|
// Port 0 is only valid for bind addresses (where the OS picks an available port),
|
||||||
|
// not for destination addresses where we need to connect.
|
||||||
|
func validateDestinationPort(addr string) error {
|
||||||
|
if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, portStr, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse address %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid port %s: %w", portStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if port == 0 {
|
||||||
|
return fmt.Errorf("port 0 is not valid for destination address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if port < 0 || port > 65535 {
|
||||||
|
return fmt.Errorf("port %d out of range (1-65535)", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||||
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||||
func parsePortForwardSpec(spec string) (string, string, error) {
|
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||||
|
|||||||
@@ -2013,6 +2013,7 @@ type SSHSessionInfo struct {
|
|||||||
RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"`
|
RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"`
|
||||||
Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"`
|
Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"`
|
||||||
JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"`
|
JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"`
|
||||||
|
PortForwards []string `protobuf:"bytes,5,rep,name=portForwards,proto3" json:"portForwards,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@@ -2075,6 +2076,13 @@ func (x *SSHSessionInfo) GetJwtUsername() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *SSHSessionInfo) GetPortForwards() []string {
|
||||||
|
if x != nil {
|
||||||
|
return x.PortForwards
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SSHServerState contains the latest state of the SSH server
|
// SSHServerState contains the latest state of the SSH server
|
||||||
type SSHServerState struct {
|
type SSHServerState struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
@@ -5706,12 +5714,13 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
|
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
|
||||||
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +
|
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +
|
||||||
"\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" +
|
"\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" +
|
||||||
"\x05error\x18\x04 \x01(\tR\x05error\"\x8e\x01\n" +
|
"\x05error\x18\x04 \x01(\tR\x05error\"\xb2\x01\n" +
|
||||||
"\x0eSSHSessionInfo\x12\x1a\n" +
|
"\x0eSSHSessionInfo\x12\x1a\n" +
|
||||||
"\busername\x18\x01 \x01(\tR\busername\x12$\n" +
|
"\busername\x18\x01 \x01(\tR\busername\x12$\n" +
|
||||||
"\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" +
|
"\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" +
|
||||||
"\acommand\x18\x03 \x01(\tR\acommand\x12 \n" +
|
"\acommand\x18\x03 \x01(\tR\acommand\x12 \n" +
|
||||||
"\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" +
|
"\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\x12\"\n" +
|
||||||
|
"\fportForwards\x18\x05 \x03(\tR\fportForwards\"^\n" +
|
||||||
"\x0eSSHServerState\x12\x18\n" +
|
"\x0eSSHServerState\x12\x18\n" +
|
||||||
"\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" +
|
"\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" +
|
||||||
"\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" +
|
"\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" +
|
||||||
|
|||||||
@@ -372,6 +372,7 @@ message SSHSessionInfo {
|
|||||||
string remoteAddress = 2;
|
string remoteAddress = 2;
|
||||||
string command = 3;
|
string command = 3;
|
||||||
string jwtUsername = 4;
|
string jwtUsername = 4;
|
||||||
|
repeated string portForwards = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHServerState contains the latest state of the SSH server
|
// SSHServerState contains the latest state of the SSH server
|
||||||
|
|||||||
@@ -1104,6 +1104,7 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
|||||||
RemoteAddress: session.RemoteAddress,
|
RemoteAddress: session.RemoteAddress,
|
||||||
Command: session.Command,
|
Command: session.Command,
|
||||||
JwtUsername: session.JWTUsername,
|
JwtUsername: session.JWTUsername,
|
||||||
|
PortForwards: session.PortForwards,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -98,19 +98,17 @@ func (a *Authorizer) Update(config *Config) {
|
|||||||
len(config.AuthorizedUsers), len(machineUsers))
|
len(config.AuthorizedUsers), len(machineUsers))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authorize validates if a user is authorized to login as the specified OS user
|
// Authorize validates if a user is authorized to login as the specified OS user.
|
||||||
// Returns nil if authorized, or an error describing why authorization failed
|
// Returns a success message describing how authorization was granted, or an error.
|
||||||
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
|
func (a *Authorizer) Authorize(jwtUserID, osUsername string) (string, error) {
|
||||||
if jwtUserID == "" {
|
if jwtUserID == "" {
|
||||||
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
|
return "", fmt.Errorf("JWT user ID is empty for OS user %q: %w", osUsername, ErrEmptyUserID)
|
||||||
return ErrEmptyUserID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hash the JWT user ID for comparison
|
// Hash the JWT user ID for comparison
|
||||||
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
|
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
|
return "", fmt.Errorf("hash user ID %q for OS user %q: %w", jwtUserID, osUsername, err)
|
||||||
return fmt.Errorf("failed to hash user ID: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
a.mu.RLock()
|
a.mu.RLock()
|
||||||
@@ -119,8 +117,7 @@ func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
|
|||||||
// Find the index of this user in the authorized list
|
// Find the index of this user in the authorized list
|
||||||
userIndex, found := a.findUserIndex(hashedUserID)
|
userIndex, found := a.findUserIndex(hashedUserID)
|
||||||
if !found {
|
if !found {
|
||||||
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
|
return "", fmt.Errorf("user %q (hash: %s) not in authorized list for OS user %q: %w", jwtUserID, hashedUserID, osUsername, ErrUserNotAuthorized)
|
||||||
return ErrUserNotAuthorized
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
|
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
|
||||||
@@ -128,12 +125,11 @@ func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
|
|||||||
|
|
||||||
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
|
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
|
||||||
// Checks wildcard mapping first, then specific OS user mappings
|
// Checks wildcard mapping first, then specific OS user mappings
|
||||||
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
|
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) (string, error) {
|
||||||
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
|
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
|
||||||
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
|
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
|
||||||
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
|
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
|
||||||
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
|
return fmt.Sprintf("granted via wildcard (index: %d)", userIndex), nil
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,18 +137,15 @@ func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userI
|
|||||||
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
|
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
|
||||||
if !hasMachineUserMapping {
|
if !hasMachineUserMapping {
|
||||||
// No mapping for this OS user - deny by default (fail closed)
|
// No mapping for this OS user - deny by default (fail closed)
|
||||||
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
|
return "", fmt.Errorf("no machine user mapping for OS user %q (JWT user: %s): %w", osUsername, jwtUserID, ErrNoMachineUserMapping)
|
||||||
return ErrNoMachineUserMapping
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if user's index is in the allowed indexes for this specific OS user
|
// Check if user's index is in the allowed indexes for this specific OS user
|
||||||
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
|
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
|
||||||
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
|
return "", fmt.Errorf("user %q not mapped to OS user %q (index: %d): %w", jwtUserID, osUsername, userIndex, ErrUserNotMappedToOSUser)
|
||||||
return ErrUserNotMappedToOSUser
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
|
return fmt.Sprintf("granted (index: %d)", userIndex), nil
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserIDClaim returns the JWT claim name used to extract user IDs
|
// GetUserIDClaim returns the JWT claim name used to extract user IDs
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// Try to authorize a different user
|
// Try to authorize a different user
|
||||||
err = authorizer.Authorize("unauthorized-user", "root")
|
_, err = authorizer.Authorize("unauthorized-user", "root")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
}
|
}
|
||||||
@@ -45,15 +45,15 @@ func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T)
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// All attempts should fail when no machine user mappings exist (fail closed)
|
// All attempts should fail when no machine user mappings exist (fail closed)
|
||||||
err = authorizer.Authorize("user1", "root")
|
_, err = authorizer.Authorize("user1", "root")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
err = authorizer.Authorize("user2", "admin")
|
_, err = authorizer.Authorize("user2", "admin")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
err = authorizer.Authorize("user1", "postgres")
|
_, err = authorizer.Authorize("user1", "postgres")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
}
|
}
|
||||||
@@ -80,21 +80,21 @@ func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testi
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// user1 (index 0) should access root and admin
|
// user1 (index 0) should access root and admin
|
||||||
err = authorizer.Authorize("user1", "root")
|
_, err = authorizer.Authorize("user1", "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = authorizer.Authorize("user1", "admin")
|
_, err = authorizer.Authorize("user1", "admin")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// user2 (index 1) should access root and postgres
|
// user2 (index 1) should access root and postgres
|
||||||
err = authorizer.Authorize("user2", "root")
|
_, err = authorizer.Authorize("user2", "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = authorizer.Authorize("user2", "postgres")
|
_, err = authorizer.Authorize("user2", "postgres")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// user3 (index 2) should access postgres
|
// user3 (index 2) should access postgres
|
||||||
err = authorizer.Authorize("user3", "postgres")
|
_, err = authorizer.Authorize("user3", "postgres")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,22 +121,22 @@ func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testin
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// user1 (index 0) should NOT access postgres
|
// user1 (index 0) should NOT access postgres
|
||||||
err = authorizer.Authorize("user1", "postgres")
|
_, err = authorizer.Authorize("user1", "postgres")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
// user2 (index 1) should NOT access admin
|
// user2 (index 1) should NOT access admin
|
||||||
err = authorizer.Authorize("user2", "admin")
|
_, err = authorizer.Authorize("user2", "admin")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
// user3 (index 2) should NOT access root
|
// user3 (index 2) should NOT access root
|
||||||
err = authorizer.Authorize("user3", "root")
|
_, err = authorizer.Authorize("user3", "root")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
// user3 (index 2) should NOT access admin
|
// user3 (index 2) should NOT access admin
|
||||||
err = authorizer.Authorize("user3", "admin")
|
_, err = authorizer.Authorize("user3", "admin")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
}
|
}
|
||||||
@@ -158,7 +158,7 @@ func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// user1 should NOT access an unmapped OS user (fail closed)
|
// user1 should NOT access an unmapped OS user (fail closed)
|
||||||
err = authorizer.Authorize("user1", "postgres")
|
_, err = authorizer.Authorize("user1", "postgres")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
}
|
}
|
||||||
@@ -178,7 +178,7 @@ func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// Empty user ID should fail
|
// Empty user ID should fail
|
||||||
err = authorizer.Authorize("", "root")
|
_, err = authorizer.Authorize("", "root")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrEmptyUserID)
|
assert.ErrorIs(t, err, ErrEmptyUserID)
|
||||||
}
|
}
|
||||||
@@ -211,12 +211,12 @@ func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
|
|||||||
|
|
||||||
// All users should be authorized for root
|
// All users should be authorized for root
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
|
_, err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
|
||||||
assert.NoError(t, err, "user%d should be authorized", i)
|
assert.NoError(t, err, "user%d should be authorized", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// User not in list should fail
|
// User not in list should fail
|
||||||
err := authorizer.Authorize("unknown-user", "root")
|
_, err := authorizer.Authorize("unknown-user", "root")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
}
|
}
|
||||||
@@ -236,14 +236,14 @@ func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// user1 should be authorized
|
// user1 should be authorized
|
||||||
err = authorizer.Authorize("user1", "root")
|
_, err = authorizer.Authorize("user1", "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Clear configuration
|
// Clear configuration
|
||||||
authorizer.Update(nil)
|
authorizer.Update(nil)
|
||||||
|
|
||||||
// user1 should no longer be authorized
|
// user1 should no longer be authorized
|
||||||
err = authorizer.Authorize("user1", "root")
|
_, err = authorizer.Authorize("user1", "root")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
}
|
}
|
||||||
@@ -267,16 +267,16 @@ func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// root should work
|
// root should work
|
||||||
err = authorizer.Authorize("user1", "root")
|
_, err = authorizer.Authorize("user1", "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// postgres should fail (no mapping)
|
// postgres should fail (no mapping)
|
||||||
err = authorizer.Authorize("user1", "postgres")
|
_, err = authorizer.Authorize("user1", "postgres")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
// admin should fail (no mapping)
|
// admin should fail (no mapping)
|
||||||
err = authorizer.Authorize("user1", "admin")
|
_, err = authorizer.Authorize("user1", "admin")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
}
|
}
|
||||||
@@ -301,7 +301,7 @@ func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
|
|||||||
assert.Equal(t, "email", authorizer.GetUserIDClaim())
|
assert.Equal(t, "email", authorizer.GetUserIDClaim())
|
||||||
|
|
||||||
// Authorize with email as user ID
|
// Authorize with email as user ID
|
||||||
err = authorizer.Authorize("user@example.com", "root")
|
_, err = authorizer.Authorize("user@example.com", "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -349,19 +349,19 @@ func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// First user should have access
|
// First user should have access
|
||||||
err := authorizer.Authorize("user"+string(rune(0)), "root")
|
_, err := authorizer.Authorize("user"+string(rune(0)), "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Middle user should have access
|
// Middle user should have access
|
||||||
err = authorizer.Authorize("user"+string(rune(500)), "root")
|
_, err = authorizer.Authorize("user"+string(rune(500)), "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Last user should have access
|
// Last user should have access
|
||||||
err = authorizer.Authorize("user"+string(rune(999)), "root")
|
_, err = authorizer.Authorize("user"+string(rune(999)), "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// User not in mapping should NOT have access
|
// User not in mapping should NOT have access
|
||||||
err = authorizer.Authorize("user"+string(rune(100)), "root")
|
_, err = authorizer.Authorize("user"+string(rune(100)), "root")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -393,7 +393,7 @@ func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
|
|||||||
if idx%2 == 0 {
|
if idx%2 == 0 {
|
||||||
user = "user2"
|
user = "user2"
|
||||||
}
|
}
|
||||||
err := authorizer.Authorize(user, "root")
|
_, err := authorizer.Authorize(user, "root")
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
@@ -426,22 +426,22 @@ func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// All authorized users should be able to access any OS user
|
// All authorized users should be able to access any OS user
|
||||||
err = authorizer.Authorize("user1", "root")
|
_, err = authorizer.Authorize("user1", "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = authorizer.Authorize("user2", "postgres")
|
_, err = authorizer.Authorize("user2", "postgres")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = authorizer.Authorize("user3", "admin")
|
_, err = authorizer.Authorize("user3", "admin")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = authorizer.Authorize("user1", "ubuntu")
|
_, err = authorizer.Authorize("user1", "ubuntu")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = authorizer.Authorize("user2", "nginx")
|
_, err = authorizer.Authorize("user2", "nginx")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = authorizer.Authorize("user3", "docker")
|
_, err = authorizer.Authorize("user3", "docker")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -462,11 +462,11 @@ func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// user1 should have access
|
// user1 should have access
|
||||||
err = authorizer.Authorize("user1", "root")
|
_, err = authorizer.Authorize("user1", "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Unauthorized user should still be denied even with wildcard
|
// Unauthorized user should still be denied even with wildcard
|
||||||
err = authorizer.Authorize("unauthorized-user", "root")
|
_, err = authorizer.Authorize("unauthorized-user", "root")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
}
|
}
|
||||||
@@ -492,17 +492,17 @@ func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
|
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
|
||||||
err = authorizer.Authorize("user1", "root")
|
_, err = authorizer.Authorize("user1", "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = authorizer.Authorize("user2", "root")
|
_, err = authorizer.Authorize("user2", "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Both users should be able to access any other OS user via wildcard
|
// Both users should be able to access any other OS user via wildcard
|
||||||
err = authorizer.Authorize("user1", "postgres")
|
_, err = authorizer.Authorize("user1", "postgres")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = authorizer.Authorize("user2", "admin")
|
_, err = authorizer.Authorize("user2", "admin")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -526,29 +526,29 @@ func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// user1 can access root
|
// user1 can access root
|
||||||
err = authorizer.Authorize("user1", "root")
|
_, err = authorizer.Authorize("user1", "root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// user2 can access postgres
|
// user2 can access postgres
|
||||||
err = authorizer.Authorize("user2", "postgres")
|
_, err = authorizer.Authorize("user2", "postgres")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// user1 cannot access postgres
|
// user1 cannot access postgres
|
||||||
err = authorizer.Authorize("user1", "postgres")
|
_, err = authorizer.Authorize("user1", "postgres")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
// user2 cannot access root
|
// user2 cannot access root
|
||||||
err = authorizer.Authorize("user2", "root")
|
_, err = authorizer.Authorize("user2", "root")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
// Neither can access unmapped OS users
|
// Neither can access unmapped OS users
|
||||||
err = authorizer.Authorize("user1", "admin")
|
_, err = authorizer.Authorize("user1", "admin")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
err = authorizer.Authorize("user2", "admin")
|
_, err = authorizer.Authorize("user2", "admin")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
}
|
}
|
||||||
@@ -578,35 +578,35 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
|
|||||||
authorizer.Update(config)
|
authorizer.Update(config)
|
||||||
|
|
||||||
// wasm (index 0) should access any OS user via wildcard
|
// wasm (index 0) should access any OS user via wildcard
|
||||||
err = authorizer.Authorize("wasm", "root")
|
_, err = authorizer.Authorize("wasm", "root")
|
||||||
assert.NoError(t, err, "wasm should access root via wildcard")
|
assert.NoError(t, err, "wasm should access root via wildcard")
|
||||||
|
|
||||||
err = authorizer.Authorize("wasm", "alice")
|
_, err = authorizer.Authorize("wasm", "alice")
|
||||||
assert.NoError(t, err, "wasm should access alice via wildcard")
|
assert.NoError(t, err, "wasm should access alice via wildcard")
|
||||||
|
|
||||||
err = authorizer.Authorize("wasm", "bob")
|
_, err = authorizer.Authorize("wasm", "bob")
|
||||||
assert.NoError(t, err, "wasm should access bob via wildcard")
|
assert.NoError(t, err, "wasm should access bob via wildcard")
|
||||||
|
|
||||||
err = authorizer.Authorize("wasm", "postgres")
|
_, err = authorizer.Authorize("wasm", "postgres")
|
||||||
assert.NoError(t, err, "wasm should access postgres via wildcard")
|
assert.NoError(t, err, "wasm should access postgres via wildcard")
|
||||||
|
|
||||||
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
|
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
|
||||||
err = authorizer.Authorize("user2", "alice")
|
_, err = authorizer.Authorize("user2", "alice")
|
||||||
assert.NoError(t, err, "user2 should access alice via explicit mapping")
|
assert.NoError(t, err, "user2 should access alice via explicit mapping")
|
||||||
|
|
||||||
err = authorizer.Authorize("user2", "bob")
|
_, err = authorizer.Authorize("user2", "bob")
|
||||||
assert.NoError(t, err, "user2 should access bob via explicit mapping")
|
assert.NoError(t, err, "user2 should access bob via explicit mapping")
|
||||||
|
|
||||||
err = authorizer.Authorize("user2", "root")
|
_, err = authorizer.Authorize("user2", "root")
|
||||||
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
|
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
|
||||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
err = authorizer.Authorize("user2", "postgres")
|
_, err = authorizer.Authorize("user2", "postgres")
|
||||||
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
|
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
|
||||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
// Unauthorized user should still be denied
|
// Unauthorized user should still be denied
|
||||||
err = authorizer.Authorize("user3", "root")
|
_, err = authorizer.Authorize("user3", "root")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -551,14 +550,15 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
|||||||
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := localConn.Close(); err != nil {
|
if err := localConn.Close(); err != nil {
|
||||||
log.Debugf("local connection close error: %v", err)
|
log.Debugf("local port forwarding: close local connection: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
channel, err := c.client.Dial("tcp", remoteAddr)
|
channel, err := c.client.Dial("tcp", remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "administratively prohibited") {
|
var openErr *ssh.OpenChannelError
|
||||||
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
|
if errors.As(err, &openErr) && openErr.Reason == ssh.Prohibited {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: port forwarding is disabled\n")
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
|
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
|
||||||
}
|
}
|
||||||
@@ -566,19 +566,11 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := channel.Close(); err != nil {
|
if err := channel.Close(); err != nil {
|
||||||
log.Debugf("remote channel close error: %v", err)
|
log.Debugf("local port forwarding: close remote channel: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
||||||
if _, err := io.Copy(channel, localConn); err != nil {
|
|
||||||
log.Debugf("local forward copy error (local->remote): %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if _, err := io.Copy(localConn, channel); err != nil {
|
|
||||||
log.Debugf("local forward copy error (remote->local): %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
||||||
@@ -633,7 +625,7 @@ func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
|
|||||||
return fmt.Errorf("send tcpip-forward request: %w", err)
|
return fmt.Errorf("send tcpip-forward request: %w", err)
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
|
return fmt.Errorf("remote port forwarding denied by server")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -676,7 +668,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := channel.Close(); err != nil {
|
if err := channel.Close(); err != nil {
|
||||||
log.Debugf("remote channel close error: %v", err)
|
log.Debugf("remote port forwarding: close remote channel: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -688,19 +680,11 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := localConn.Close(); err != nil {
|
if err := localConn.Close(); err != nil {
|
||||||
log.Debugf("local connection close error: %v", err)
|
log.Debugf("remote port forwarding: close local connection: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
||||||
if _, err := io.Copy(localConn, channel); err != nil {
|
|
||||||
log.Debugf("remote forward copy error (remote->local): %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if _, err := io.Copy(channel, localConn); err != nil {
|
|
||||||
log.Debugf("remote forward copy error (local->remote): %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// tcpipForwardMsg represents the structure for tcpip-forward requests
|
// tcpipForwardMsg represents the structure for tcpip-forward requests
|
||||||
|
|||||||
@@ -193,3 +193,64 @@ func buildAddressList(hostname string, remote net.Addr) []string {
|
|||||||
}
|
}
|
||||||
return addresses
|
return addresses
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
|
||||||
|
// It waits for both directions to complete before returning.
|
||||||
|
// The caller is responsible for closing the connections.
|
||||||
|
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
|
||||||
|
done := make(chan struct{}, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
|
||||||
|
logger.Debugf("copy error (1->2): %v", err)
|
||||||
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
|
||||||
|
logger.Debugf("copy error (2->1): %v", err)
|
||||||
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-done
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
func isExpectedCopyError(err error) bool {
|
||||||
|
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
|
||||||
|
// It waits for both directions to complete or for context cancellation before returning.
|
||||||
|
// Both connections are closed when the function returns.
|
||||||
|
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
|
||||||
|
done := make(chan struct{}, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
|
||||||
|
logger.Debugf("copy error (1->2): %v", err)
|
||||||
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
|
||||||
|
logger.Debugf("copy error (2->1): %v", err)
|
||||||
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-done:
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = conn1.Close()
|
||||||
|
_ = conn2.Close()
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -42,6 +43,14 @@ type SSHProxy struct {
|
|||||||
conn *grpc.ClientConn
|
conn *grpc.ClientConn
|
||||||
daemonClient proto.DaemonServiceClient
|
daemonClient proto.DaemonServiceClient
|
||||||
browserOpener func(string) error
|
browserOpener func(string) error
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
backendClient *cryptossh.Client
|
||||||
|
// jwtToken is set once in runProxySSHServer before any handlers are called,
|
||||||
|
// so concurrent access is safe without additional synchronization.
|
||||||
|
jwtToken string
|
||||||
|
|
||||||
|
forwardedChannelsOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
|
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
|
||||||
@@ -63,6 +72,17 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browse
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) Close() error {
|
func (p *SSHProxy) Close() error {
|
||||||
|
p.mu.Lock()
|
||||||
|
backendClient := p.backendClient
|
||||||
|
p.backendClient = nil
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if backendClient != nil {
|
||||||
|
if err := backendClient.Close(); err != nil {
|
||||||
|
log.Debugf("close backend client: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if p.conn != nil {
|
if p.conn != nil {
|
||||||
return p.conn.Close()
|
return p.conn.Close()
|
||||||
}
|
}
|
||||||
@@ -77,16 +97,16 @@ func (p *SSHProxy) Connect(ctx context.Context) error {
|
|||||||
return fmt.Errorf(jwtAuthErrorMsg, err)
|
return fmt.Errorf(jwtAuthErrorMsg, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.runProxySSHServer(ctx, jwtToken)
|
log.Debugf("JWT authentication successful, starting proxy to %s:%d", p.targetHost, p.targetPort)
|
||||||
|
return p.runProxySSHServer(jwtToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
|
func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
|
||||||
|
p.jwtToken = jwtToken
|
||||||
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
|
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
|
||||||
|
|
||||||
sshServer := &ssh.Server{
|
sshServer := &ssh.Server{
|
||||||
Handler: func(s ssh.Session) {
|
Handler: p.handleSSHSession,
|
||||||
p.handleSSHSession(ctx, s, jwtToken)
|
|
||||||
},
|
|
||||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||||
"session": ssh.DefaultSessionHandler,
|
"session": ssh.DefaultSessionHandler,
|
||||||
"direct-tcpip": p.directTCPIPHandler,
|
"direct-tcpip": p.directTCPIPHandler,
|
||||||
@@ -119,15 +139,20 @@ func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
|
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||||
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
ptyReq, winCh, isPty := session.Pty()
|
||||||
|
hasCommand := len(session.Command()) > 0
|
||||||
|
|
||||||
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
|
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
|
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = sshClient.Close() }()
|
|
||||||
|
if !isPty && !hasCommand {
|
||||||
|
p.handleNonInteractiveSession(session, sshClient)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
serverSession, err := sshClient.NewSession()
|
serverSession, err := sshClient.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -140,7 +165,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
|
|||||||
serverSession.Stdout = session
|
serverSession.Stdout = session
|
||||||
serverSession.Stderr = session.Stderr()
|
serverSession.Stderr = session.Stderr()
|
||||||
|
|
||||||
ptyReq, winCh, isPty := session.Pty()
|
|
||||||
if isPty {
|
if isPty {
|
||||||
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
|
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
|
||||||
log.Debugf("PTY request to backend: %v", err)
|
log.Debugf("PTY request to backend: %v", err)
|
||||||
@@ -155,7 +179,7 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(session.Command()) > 0 {
|
if hasCommand {
|
||||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||||
log.Debugf("run command: %v", err)
|
log.Debugf("run command: %v", err)
|
||||||
p.handleProxyExitCode(session, err)
|
p.handleProxyExitCode(session, err)
|
||||||
@@ -176,12 +200,29 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
|
|||||||
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
||||||
var exitErr *cryptossh.ExitError
|
var exitErr *cryptossh.ExitError
|
||||||
if errors.As(err, &exitErr) {
|
if errors.As(err, &exitErr) {
|
||||||
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
|
if err := session.Exit(exitErr.ExitStatus()); err != nil {
|
||||||
log.Debugf("set exit status: %v", exitErr)
|
log.Debugf("set exit status: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
|
||||||
|
// Create a backend session to mirror the client's session request.
|
||||||
|
// This keeps the connection alive on the server side while port forwarding channels operate.
|
||||||
|
serverSession, err := sshClient.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = serverSession.Close() }()
|
||||||
|
|
||||||
|
<-session.Context().Done()
|
||||||
|
|
||||||
|
if err := session.Exit(0); err != nil {
|
||||||
|
log.Debugf("session exit: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func generateHostKey() (ssh.Signer, error) {
|
func generateHostKey() (ssh.Signer, error) {
|
||||||
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -250,8 +291,52 @@ func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
|
// directTCPIPHandler handles local port forwarding (direct-tcpip channel).
|
||||||
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
|
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, sshCtx ssh.Context) {
|
||||||
|
var payload struct {
|
||||||
|
DestAddr string
|
||||||
|
DestPort uint32
|
||||||
|
OriginAddr string
|
||||||
|
OriginPort uint32
|
||||||
|
}
|
||||||
|
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "parse direct-tcpip payload: %v\n", err)
|
||||||
|
_ = newChan.Reject(cryptossh.ConnectionFailed, "invalid payload")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort)
|
||||||
|
log.Debugf("local port forwarding: %s", dest)
|
||||||
|
|
||||||
|
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "backend connection for port forwarding: %v\n", err)
|
||||||
|
_ = newChan.Reject(cryptossh.ConnectionFailed, "backend connection failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
backendChan, backendReqs, err := backendClient.OpenChannel("direct-tcpip", newChan.ExtraData())
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "open backend channel for %s: %v\n", dest, err)
|
||||||
|
var openErr *cryptossh.OpenChannelError
|
||||||
|
if errors.As(err, &openErr) {
|
||||||
|
_ = newChan.Reject(openErr.Reason, openErr.Message)
|
||||||
|
} else {
|
||||||
|
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go cryptossh.DiscardRequests(backendReqs)
|
||||||
|
|
||||||
|
clientChan, clientReqs, err := newChan.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("local port forwarding: accept channel: %v", err)
|
||||||
|
_ = backendChan.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go cryptossh.DiscardRequests(clientReqs)
|
||||||
|
|
||||||
|
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
||||||
@@ -354,12 +439,143 @@ func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.Wr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
// tcpipForwardHandler handles remote port forwarding (tcpip-forward request).
|
||||||
return false, []byte("port forwarding not supported in proxy")
|
func (p *SSHProxy) tcpipForwardHandler(sshCtx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
|
||||||
|
var reqPayload struct {
|
||||||
|
Host string
|
||||||
|
Port uint32
|
||||||
|
}
|
||||||
|
if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "parse tcpip-forward payload: %v\n", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
|
||||||
|
|
||||||
|
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "backend connection for remote port forwarding: %v\n", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "forward tcpip-forward request for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
actualPort := reqPayload.Port
|
||||||
|
if reqPayload.Port == 0 && len(payload) >= 4 {
|
||||||
|
actualPort = binary.BigEndian.Uint32(payload)
|
||||||
|
}
|
||||||
|
log.Debugf("remote port forwarding established for %s:%d", reqPayload.Host, actualPort)
|
||||||
|
p.forwardedChannelsOnce.Do(func() {
|
||||||
|
go p.handleForwardedChannels(sshCtx, backendClient)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok, payload
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
// cancelTcpipForwardHandler handles cancel-tcpip-forward request.
|
||||||
return true, nil
|
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
|
||||||
|
var reqPayload struct {
|
||||||
|
Host string
|
||||||
|
Port uint32
|
||||||
|
}
|
||||||
|
if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "parse cancel-tcpip-forward payload: %v\n", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("cancel-tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
|
||||||
|
|
||||||
|
backendClient := p.getBackendClient()
|
||||||
|
if backendClient == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "cancel-tcpip-forward for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok, payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOrCreateBackendClient returns the existing backend client or creates a new one.
|
||||||
|
func (p *SSHProxy) getOrCreateBackendClient(ctx context.Context, user string) (*cryptossh.Client, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
if p.backendClient != nil {
|
||||||
|
return p.backendClient, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||||
|
log.Debugf("connecting to backend %s", targetAddr)
|
||||||
|
|
||||||
|
client, err := p.dialBackend(ctx, targetAddr, user, p.jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("backend connection established to %s", targetAddr)
|
||||||
|
p.backendClient = client
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBackendClient returns the existing backend client or nil.
|
||||||
|
func (p *SSHProxy) getBackendClient() *cryptossh.Client {
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
return p.backendClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleForwardedChannels handles forwarded-tcpip channels from the backend for remote port forwarding.
|
||||||
|
// When the backend receives incoming connections on the forwarded port, it sends them as
|
||||||
|
// "forwarded-tcpip" channels which we need to proxy to the client.
|
||||||
|
func (p *SSHProxy) handleForwardedChannels(sshCtx ssh.Context, backendClient *cryptossh.Client) {
|
||||||
|
sshConn, ok := sshCtx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
|
||||||
|
if !ok || sshConn == nil {
|
||||||
|
log.Debugf("no SSH connection in context for forwarded channels")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
channelChan := backendClient.HandleChannelOpen("forwarded-tcpip")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-sshCtx.Done():
|
||||||
|
return
|
||||||
|
case newChannel, ok := <-channelChan:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go p.handleForwardedChannel(sshCtx, sshConn, newChannel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleForwardedChannel handles a single forwarded-tcpip channel from the backend.
|
||||||
|
func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh.ServerConn, newChannel cryptossh.NewChannel) {
|
||||||
|
backendChan, backendReqs, err := newChannel.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("remote port forwarding: accept from backend: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go cryptossh.DiscardRequests(backendReqs)
|
||||||
|
|
||||||
|
clientChan, clientReqs, err := sshConn.OpenChannel("forwarded-tcpip", newChannel.ExtraData())
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("remote port forwarding: open to client: %v", err)
|
||||||
|
_ = backendChan.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go cryptossh.DiscardRequests(clientReqs)
|
||||||
|
|
||||||
|
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
||||||
|
|||||||
@@ -1,25 +1,32 @@
|
|||||||
|
// Package server implements port forwarding for the SSH server.
|
||||||
|
//
|
||||||
|
// Security note: Port forwarding runs in the main server process without privilege separation.
|
||||||
|
// The attack surface is primarily io.Copy through well-tested standard library code, making it
|
||||||
|
// lower risk than shell execution which uses privilege-separated child processes. We enforce
|
||||||
|
// user-level port restrictions: non-privileged users cannot bind to ports < 1024.
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gliderlabs/ssh"
|
"github.com/gliderlabs/ssh"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionKey uniquely identifies an SSH session
|
const privilegedPortThreshold = 1024
|
||||||
type SessionKey string
|
|
||||||
|
|
||||||
// ConnectionKey uniquely identifies a port forwarding connection within a session
|
// sessionKey uniquely identifies an SSH session
|
||||||
type ConnectionKey string
|
type sessionKey string
|
||||||
|
|
||||||
// ForwardKey uniquely identifies a port forwarding listener
|
// forwardKey uniquely identifies a port forwarding listener
|
||||||
type ForwardKey string
|
type forwardKey string
|
||||||
|
|
||||||
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
|
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
|
||||||
type tcpipForwardMsg struct {
|
type tcpipForwardMsg struct {
|
||||||
@@ -47,34 +54,32 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
|
|||||||
allowRemote := s.allowRemotePortForwarding
|
allowRemote := s.allowRemotePortForwarding
|
||||||
|
|
||||||
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
|
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
|
||||||
|
logger := s.getRequestLogger(ctx)
|
||||||
if !allowLocal {
|
if !allowLocal {
|
||||||
log.Warnf("local port forwarding denied for %s from %s: disabled by configuration",
|
logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
|
||||||
net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr())
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
|
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
|
||||||
log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err)
|
logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||||
|
logger := s.getRequestLogger(ctx)
|
||||||
if !allowRemote {
|
if !allowRemote {
|
||||||
log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration",
|
logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
|
||||||
net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr())
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
|
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
|
||||||
log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err)
|
logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,23 +87,20 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
|
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
|
||||||
// Returns nil if allowed, error if denied.
|
// For remote port forwarding (binding), it enforces that non-privileged users cannot bind to
|
||||||
|
// ports below 1024, mirroring the restriction they would face if binding directly.
|
||||||
|
//
|
||||||
|
// Note: FeatureSupportsUserSwitch is true because we accept requests from any authenticated user,
|
||||||
|
// though we don't actually switch users - port forwarding runs in the server process. The resolved
|
||||||
|
// user is used for privileged port access checks.
|
||||||
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
|
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return fmt.Errorf("%s port forwarding denied: no context", forwardType)
|
return fmt.Errorf("%s port forwarding denied: no context", forwardType)
|
||||||
}
|
}
|
||||||
|
|
||||||
username := ctx.User()
|
|
||||||
remoteAddr := "unknown"
|
|
||||||
if ctx.RemoteAddr() != nil {
|
|
||||||
remoteAddr = ctx.RemoteAddr().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port})
|
|
||||||
|
|
||||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||||
RequestedUsername: username,
|
RequestedUsername: ctx.User(),
|
||||||
FeatureSupportsUserSwitch: false,
|
FeatureSupportsUserSwitch: true,
|
||||||
FeatureName: forwardType + " port forwarding",
|
FeatureName: forwardType + " port forwarding",
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -106,12 +108,42 @@ func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType stri
|
|||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debugf("%s port forwarding allowed: user %s validated (port %d)",
|
if err := s.checkPrivilegedPortAccess(forwardType, port, result); err != nil {
|
||||||
forwardType, result.User.Username, port)
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkPrivilegedPortAccess enforces that non-privileged users cannot bind to privileged ports.
|
||||||
|
// This applies to remote port forwarding where the server binds a port on behalf of the user.
|
||||||
|
// On Windows, there is no privileged port restriction, so this check is skipped.
|
||||||
|
func (s *Server) checkPrivilegedPortAccess(forwardType string, port uint32, result PrivilegeCheckResult) error {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
isBindOperation := forwardType == "remote" || forwardType == "tcpip-forward"
|
||||||
|
if !isBindOperation {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Port 0 means "pick any available port", which will be >= 1024
|
||||||
|
if port == 0 || port >= privilegedPortThreshold {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.User != nil && isPrivilegedUsername(result.User.Username) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
username := "unknown"
|
||||||
|
if result.User != nil {
|
||||||
|
username = result.User.Username
|
||||||
|
}
|
||||||
|
return fmt.Errorf("user %s cannot bind to privileged port %d (requires root)", username, port)
|
||||||
|
}
|
||||||
|
|
||||||
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
|
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
|
||||||
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
|
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
|
||||||
logger := s.getRequestLogger(ctx)
|
logger := s.getRequestLogger(ctx)
|
||||||
@@ -132,8 +164,6 @@ func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *crypto
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
|
|
||||||
|
|
||||||
sshConn, err := s.getSSHConnection(ctx)
|
sshConn, err := s.getSSHConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warnf("tcpip-forward request denied: %v", err)
|
logger.Warnf("tcpip-forward request denied: %v", err)
|
||||||
@@ -153,8 +183,10 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||||
if s.removeRemoteForwardListener(key) {
|
if s.removeRemoteForwardListener(key) {
|
||||||
|
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
|
||||||
|
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
|
||||||
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
|
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@@ -165,14 +197,11 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
|
|||||||
|
|
||||||
// handleRemoteForwardListener handles incoming connections for remote port forwarding.
|
// handleRemoteForwardListener handles incoming connections for remote port forwarding.
|
||||||
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
|
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
|
||||||
log.Debugf("starting remote forward listener handler for %s:%d", host, port)
|
logger := s.getRequestLogger(ctx)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
log.Debugf("cleaning up remote forward listener for %s:%d", host, port)
|
|
||||||
if err := ln.Close(); err != nil {
|
if err := ln.Close(); err != nil {
|
||||||
log.Debugf("remote forward listener close error: %v", err)
|
logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
|
||||||
} else {
|
|
||||||
log.Debugf("remote forward listener closed successfully for %s:%d", host, port)
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -196,28 +225,43 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
|
|||||||
select {
|
select {
|
||||||
case result := <-acceptChan:
|
case result := <-acceptChan:
|
||||||
if result.err != nil {
|
if result.err != nil {
|
||||||
log.Debugf("remote forward accept error: %v", result.err)
|
logger.Debugf("remote forward accept error: %v", result.err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
|
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port)
|
logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRequestLogger creates a logger with user and remote address context
|
// getRequestLogger creates a logger with session/conn and jwt_user context
|
||||||
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
|
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
|
||||||
remoteAddr := "unknown"
|
sessionKey := s.findSessionKeyByContext(ctx)
|
||||||
username := "unknown"
|
|
||||||
if ctx != nil {
|
s.mu.RLock()
|
||||||
if ctx.RemoteAddr() != nil {
|
defer s.mu.RUnlock()
|
||||||
remoteAddr = ctx.RemoteAddr().String()
|
|
||||||
|
if state, exists := s.sessions[sessionKey]; exists {
|
||||||
|
logger := log.WithField("session", sessionKey)
|
||||||
|
if state.jwtUsername != "" {
|
||||||
|
logger = logger.WithField("jwt_user", state.jwtUsername)
|
||||||
}
|
}
|
||||||
username = ctx.User()
|
return logger
|
||||||
}
|
}
|
||||||
return log.WithFields(log.Fields{"user": username, "remote": remoteAddr})
|
|
||||||
|
if ctx.RemoteAddr() != nil {
|
||||||
|
if connState, exists := s.connections[connKey(ctx.RemoteAddr().String())]; exists {
|
||||||
|
return s.connLogger(connState)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAddr := "unknown"
|
||||||
|
if ctx.RemoteAddr() != nil {
|
||||||
|
remoteAddr = ctx.RemoteAddr().String()
|
||||||
|
}
|
||||||
|
return log.WithField("session", fmt.Sprintf("%s@%s", ctx.User(), remoteAddr))
|
||||||
}
|
}
|
||||||
|
|
||||||
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
|
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
|
||||||
@@ -227,6 +271,13 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
|
|||||||
return s.allowRemotePortForwarding
|
return s.allowRemotePortForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
|
||||||
|
func (s *Server) isPortForwardingEnabled() bool {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
|
||||||
|
}
|
||||||
|
|
||||||
// parseTcpipForwardRequest parses the SSH request payload
|
// parseTcpipForwardRequest parses the SSH request payload
|
||||||
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
||||||
var payload tcpipForwardMsg
|
var payload tcpipForwardMsg
|
||||||
@@ -267,10 +318,11 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
|
|||||||
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
|
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||||
s.storeRemoteForwardListener(key, ln)
|
s.storeRemoteForwardListener(key, ln)
|
||||||
|
|
||||||
s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String())
|
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
|
||||||
|
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||||
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
|
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
|
||||||
|
|
||||||
response := make([]byte, 4)
|
response := make([]byte, 4)
|
||||||
@@ -288,44 +340,34 @@ type acceptResult struct {
|
|||||||
|
|
||||||
// handleRemoteForwardConnection handles a single remote port forwarding connection
|
// handleRemoteForwardConnection handles a single remote port forwarding connection
|
||||||
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
|
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
|
||||||
sessionKey := s.findSessionKeyByContext(ctx)
|
logger := s.getRequestLogger(ctx)
|
||||||
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
|
|
||||||
logger := log.WithFields(log.Fields{
|
|
||||||
"session": sessionKey,
|
|
||||||
"conn": connID,
|
|
||||||
})
|
|
||||||
|
|
||||||
defer func() {
|
sshConn, ok := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
|
||||||
if err := conn.Close(); err != nil {
|
if !ok || sshConn == nil {
|
||||||
logger.Debugf("connection close error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
|
|
||||||
if sshConn == nil {
|
|
||||||
logger.Debugf("remote forward: no SSH connection in context")
|
logger.Debugf("remote forward: no SSH connection in context")
|
||||||
|
_ = conn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
|
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
|
||||||
|
_ = conn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
|
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debugf("open forward channel: %v", err)
|
logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
|
||||||
|
_ = conn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.proxyForwardConnection(ctx, logger, conn, channel)
|
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// openForwardChannel creates an SSH forwarded-tcpip channel
|
// openForwardChannel creates an SSH forwarded-tcpip channel
|
||||||
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) {
|
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr) (cryptossh.Channel, error) {
|
||||||
logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port)
|
|
||||||
|
|
||||||
payload := struct {
|
payload := struct {
|
||||||
ConnectedAddress string
|
ConnectedAddress string
|
||||||
ConnectedPort uint32
|
ConnectedPort uint32
|
||||||
@@ -346,41 +388,3 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string,
|
|||||||
go cryptossh.DiscardRequests(reqs)
|
go cryptossh.DiscardRequests(reqs)
|
||||||
return channel, nil
|
return channel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
|
|
||||||
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
|
|
||||||
done := make(chan struct{}, 2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if _, err := io.Copy(channel, conn); err != nil {
|
|
||||||
logger.Debugf("copy error (conn->channel): %v", err)
|
|
||||||
}
|
|
||||||
done <- struct{}{}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if _, err := io.Copy(conn, channel); err != nil {
|
|
||||||
logger.Debugf("copy error (channel->conn): %v", err)
|
|
||||||
}
|
|
||||||
done <- struct{}{}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
logger.Debugf("session ended, closing connections")
|
|
||||||
case <-done:
|
|
||||||
// First copy finished, wait for second copy or context cancellation
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
logger.Debugf("session ended, closing connections")
|
|
||||||
case <-done:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := channel.Close(); err != nil {
|
|
||||||
logger.Debugf("channel close error: %v", err)
|
|
||||||
}
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
logger.Debugf("connection close error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -40,6 +41,11 @@ const (
|
|||||||
|
|
||||||
msgPrivilegedUserDisabled = "privileged user login is disabled"
|
msgPrivilegedUserDisabled = "privileged user login is disabled"
|
||||||
|
|
||||||
|
cmdInteractiveShell = "<interactive shell>"
|
||||||
|
cmdPortForwarding = "<port forwarding>"
|
||||||
|
cmdSFTP = "<sftp>"
|
||||||
|
cmdNonInteractive = "<idle>"
|
||||||
|
|
||||||
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
|
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
|
||||||
DefaultJWTMaxTokenAge = 5 * 60
|
DefaultJWTMaxTokenAge = 5 * 60
|
||||||
)
|
)
|
||||||
@@ -90,10 +96,10 @@ func logSessionExitError(logger *log.Entry, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// safeLogCommand returns a safe representation of the command for logging
|
// safeLogCommand returns a safe representation of the command for logging.
|
||||||
func safeLogCommand(cmd []string) string {
|
func safeLogCommand(cmd []string) string {
|
||||||
if len(cmd) == 0 {
|
if len(cmd) == 0 {
|
||||||
return "<interactive shell>"
|
return cmdInteractiveShell
|
||||||
}
|
}
|
||||||
if len(cmd) == 1 {
|
if len(cmd) == 1 {
|
||||||
return cmd[0]
|
return cmd[0]
|
||||||
@@ -101,26 +107,50 @@ func safeLogCommand(cmd []string) string {
|
|||||||
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
|
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
type sshConnectionState struct {
|
// connState tracks the state of an SSH connection for port forwarding and status display.
|
||||||
hasActivePortForward bool
|
type connState struct {
|
||||||
username string
|
username string
|
||||||
remoteAddr string
|
remoteAddr net.Addr
|
||||||
|
portForwards []string
|
||||||
|
jwtUsername string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authKey uniquely identifies an authentication attempt by username and remote address.
|
||||||
|
// Used to temporarily store JWT username between passwordHandler and sessionHandler.
|
||||||
type authKey string
|
type authKey string
|
||||||
|
|
||||||
|
// connKey uniquely identifies an SSH connection by its remote address.
|
||||||
|
// Used to track authenticated connections for status display and port forwarding.
|
||||||
|
type connKey string
|
||||||
|
|
||||||
func newAuthKey(username string, remoteAddr net.Addr) authKey {
|
func newAuthKey(username string, remoteAddr net.Addr) authKey {
|
||||||
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
|
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sessionState tracks an active SSH session (shell, command, or subsystem like SFTP).
|
||||||
|
type sessionState struct {
|
||||||
|
session ssh.Session
|
||||||
|
sessionType string
|
||||||
|
jwtUsername string
|
||||||
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
sshServer *ssh.Server
|
sshServer *ssh.Server
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
hostKeyPEM []byte
|
hostKeyPEM []byte
|
||||||
sessions map[SessionKey]ssh.Session
|
|
||||||
sessionCancels map[ConnectionKey]context.CancelFunc
|
// sessions tracks active SSH sessions (shell, command, SFTP).
|
||||||
sessionJWTUsers map[SessionKey]string
|
// These are created when a client opens a session channel and requests shell/exec/subsystem.
|
||||||
pendingAuthJWT map[authKey]string
|
sessions map[sessionKey]*sessionState
|
||||||
|
|
||||||
|
// pendingAuthJWT temporarily stores JWT username during the auth→session handoff.
|
||||||
|
// Populated in passwordHandler, consumed in sessionHandler/sftpSubsystemHandler.
|
||||||
|
pendingAuthJWT map[authKey]string
|
||||||
|
|
||||||
|
// connections tracks all SSH connections by their remote address.
|
||||||
|
// Populated at authentication time, stores JWT username and port forwards for status display.
|
||||||
|
connections map[connKey]*connState
|
||||||
|
|
||||||
|
|
||||||
allowLocalPortForwarding bool
|
allowLocalPortForwarding bool
|
||||||
allowRemotePortForwarding bool
|
allowRemotePortForwarding bool
|
||||||
@@ -132,8 +162,7 @@ type Server struct {
|
|||||||
|
|
||||||
wgAddress wgaddr.Address
|
wgAddress wgaddr.Address
|
||||||
|
|
||||||
remoteForwardListeners map[ForwardKey]net.Listener
|
remoteForwardListeners map[forwardKey]net.Listener
|
||||||
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
|
|
||||||
|
|
||||||
jwtValidator *jwt.Validator
|
jwtValidator *jwt.Validator
|
||||||
jwtExtractor *jwt.ClaimsExtractor
|
jwtExtractor *jwt.ClaimsExtractor
|
||||||
@@ -167,6 +196,7 @@ type SessionInfo struct {
|
|||||||
RemoteAddress string
|
RemoteAddress string
|
||||||
Command string
|
Command string
|
||||||
JWTUsername string
|
JWTUsername string
|
||||||
|
PortForwards []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||||
@@ -175,11 +205,10 @@ func New(config *Config) *Server {
|
|||||||
s := &Server{
|
s := &Server{
|
||||||
mu: sync.RWMutex{},
|
mu: sync.RWMutex{},
|
||||||
hostKeyPEM: config.HostKeyPEM,
|
hostKeyPEM: config.HostKeyPEM,
|
||||||
sessions: make(map[SessionKey]ssh.Session),
|
sessions: make(map[sessionKey]*sessionState),
|
||||||
sessionJWTUsers: make(map[SessionKey]string),
|
|
||||||
pendingAuthJWT: make(map[authKey]string),
|
pendingAuthJWT: make(map[authKey]string),
|
||||||
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
remoteForwardListeners: make(map[forwardKey]net.Listener),
|
||||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
connections: make(map[connKey]*connState),
|
||||||
jwtEnabled: config.JWT != nil,
|
jwtEnabled: config.JWT != nil,
|
||||||
jwtConfig: config.JWT,
|
jwtConfig: config.JWT,
|
||||||
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
|
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
|
||||||
@@ -265,14 +294,8 @@ func (s *Server) Stop() error {
|
|||||||
s.sshServer = nil
|
s.sshServer = nil
|
||||||
|
|
||||||
maps.Clear(s.sessions)
|
maps.Clear(s.sessions)
|
||||||
maps.Clear(s.sessionJWTUsers)
|
|
||||||
maps.Clear(s.pendingAuthJWT)
|
maps.Clear(s.pendingAuthJWT)
|
||||||
maps.Clear(s.sshConnections)
|
maps.Clear(s.connections)
|
||||||
|
|
||||||
for _, cancelFunc := range s.sessionCancels {
|
|
||||||
cancelFunc()
|
|
||||||
}
|
|
||||||
maps.Clear(s.sessionCancels)
|
|
||||||
|
|
||||||
for _, listener := range s.remoteForwardListeners {
|
for _, listener := range s.remoteForwardListeners {
|
||||||
if err := listener.Close(); err != nil {
|
if err := listener.Close(); err != nil {
|
||||||
@@ -284,32 +307,70 @@ func (s *Server) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStatus returns the current status of the SSH server and active sessions
|
// GetStatus returns the current status of the SSH server and active sessions.
|
||||||
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
enabled = s.sshServer != nil
|
enabled = s.sshServer != nil
|
||||||
|
reportedAddrs := make(map[string]bool)
|
||||||
|
|
||||||
for sessionKey, session := range s.sessions {
|
for _, state := range s.sessions {
|
||||||
cmd := "<interactive shell>"
|
info := s.buildSessionInfo(state)
|
||||||
if len(session.Command()) > 0 {
|
reportedAddrs[info.RemoteAddress] = true
|
||||||
cmd = safeLogCommand(session.Command())
|
sessions = append(sessions, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
|
||||||
|
for key, connState := range s.connections {
|
||||||
|
remoteAddr := string(key)
|
||||||
|
if reportedAddrs[remoteAddr] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cmd := cmdNonInteractive
|
||||||
|
if len(connState.portForwards) > 0 {
|
||||||
|
cmd = cmdPortForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtUsername := s.sessionJWTUsers[sessionKey]
|
|
||||||
|
|
||||||
sessions = append(sessions, SessionInfo{
|
sessions = append(sessions, SessionInfo{
|
||||||
Username: session.User(),
|
Username: connState.username,
|
||||||
RemoteAddress: session.RemoteAddr().String(),
|
RemoteAddress: remoteAddr,
|
||||||
Command: cmd,
|
Command: cmd,
|
||||||
JWTUsername: jwtUsername,
|
JWTUsername: connState.jwtUsername,
|
||||||
|
PortForwards: connState.portForwards,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return enabled, sessions
|
return enabled, sessions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
|
||||||
|
session := state.session
|
||||||
|
cmd := state.sessionType
|
||||||
|
if cmd == "" {
|
||||||
|
cmd = safeLogCommand(session.Command())
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAddr := session.RemoteAddr().String()
|
||||||
|
info := SessionInfo{
|
||||||
|
Username: session.User(),
|
||||||
|
RemoteAddress: remoteAddr,
|
||||||
|
Command: cmd,
|
||||||
|
JWTUsername: state.jwtUsername,
|
||||||
|
}
|
||||||
|
|
||||||
|
connState, exists := s.connections[connKey(remoteAddr)]
|
||||||
|
if !exists {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
info.PortForwards = connState.portForwards
|
||||||
|
if len(connState.portForwards) > 0 && (cmd == cmdInteractiveShell || cmd == cmdNonInteractive) {
|
||||||
|
info.Command = cmdPortForwarding
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
// SetNetstackNet sets the netstack network for userspace networking
|
// SetNetstackNet sets the netstack network for userspace networking
|
||||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
@@ -520,69 +581,129 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
|
|||||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||||
osUsername := ctx.User()
|
osUsername := ctx.User()
|
||||||
remoteAddr := ctx.RemoteAddr()
|
remoteAddr := ctx.RemoteAddr()
|
||||||
|
logger := s.getRequestLogger(ctx)
|
||||||
|
|
||||||
if err := s.ensureJWTValidator(); err != nil {
|
if err := s.ensureJWTValidator(); err != nil {
|
||||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
logger.Errorf("JWT validator initialization failed: %v", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.validateJWTToken(password)
|
token, err := s.validateJWTToken(password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
logger.Warnf("JWT authentication failed: %v", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
userAuth, err := s.extractAndValidateUser(token)
|
userAuth, err := s.extractAndValidateUser(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
logger.Warnf("user validation failed: %v", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger = logger.WithField("jwt_user", userAuth.UserId)
|
||||||
|
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
authorizer := s.authorizer
|
authorizer := s.authorizer
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
|
msg, err := authorizer.Authorize(userAuth.UserId, osUsername)
|
||||||
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
|
if err != nil {
|
||||||
|
logger.Warnf("SSH auth denied: %v", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Infof("SSH auth %s", msg)
|
||||||
|
|
||||||
key := newAuthKey(osUsername, remoteAddr)
|
key := newAuthKey(osUsername, remoteAddr)
|
||||||
|
remoteAddrStr := ctx.RemoteAddr().String()
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.pendingAuthJWT[key] = userAuth.UserId
|
s.pendingAuthJWT[key] = userAuth.UserId
|
||||||
|
s.connections[connKey(remoteAddrStr)] = &connState{
|
||||||
|
username: ctx.User(),
|
||||||
|
remoteAddr: ctx.RemoteAddr(),
|
||||||
|
jwtUsername: userAuth.UserId,
|
||||||
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
|
func (s *Server) addConnectionPortForward(username string, remoteAddr net.Addr, forwardAddr string) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
if state, exists := s.sshConnections[sshConn]; exists {
|
key := connKey(remoteAddr.String())
|
||||||
state.hasActivePortForward = true
|
if state, exists := s.connections[key]; exists {
|
||||||
} else {
|
if !slices.Contains(state.portForwards, forwardAddr) {
|
||||||
s.sshConnections[sshConn] = &sshConnectionState{
|
state.portForwards = append(state.portForwards, forwardAddr)
|
||||||
hasActivePortForward: true,
|
|
||||||
username: username,
|
|
||||||
remoteAddr: remoteAddr,
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection not in connections (non-JWT auth path)
|
||||||
|
s.connections[key] = &connState{
|
||||||
|
username: username,
|
||||||
|
remoteAddr: remoteAddr,
|
||||||
|
portForwards: []string{forwardAddr},
|
||||||
|
jwtUsername: s.pendingAuthJWT[newAuthKey(username, remoteAddr)],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
|
func (s *Server) removeConnectionPortForward(remoteAddr net.Addr, forwardAddr string) {
|
||||||
// We can't extract the SSH connection from net.Conn directly
|
s.mu.Lock()
|
||||||
// Connection cleanup will happen during session cleanup or via timeout
|
defer s.mu.Unlock()
|
||||||
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
|
|
||||||
|
state, exists := s.connections[connKey(remoteAddr.String())]
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state.portForwards = slices.DeleteFunc(state.portForwards, func(addr string) bool {
|
||||||
|
return addr == forwardAddr
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
// trackedConn wraps a net.Conn to detect when it closes
|
||||||
|
type trackedConn struct {
|
||||||
|
net.Conn
|
||||||
|
server *Server
|
||||||
|
remoteAddr string
|
||||||
|
onceClose sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *trackedConn) Close() error {
|
||||||
|
err := c.Conn.Close()
|
||||||
|
c.onceClose.Do(func() {
|
||||||
|
c.server.handleConnectionClose(c.remoteAddr)
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleConnectionClose(remoteAddr string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
key := connKey(remoteAddr)
|
||||||
|
state, exists := s.connections[key]
|
||||||
|
if exists && len(state.portForwards) > 0 {
|
||||||
|
s.connLogger(state).Info("port forwarding connection closed")
|
||||||
|
}
|
||||||
|
delete(s.connections, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) connLogger(state *connState) *log.Entry {
|
||||||
|
logger := log.WithField("session", fmt.Sprintf("%s@%s", state.username, state.remoteAddr))
|
||||||
|
if state.jwtUsername != "" {
|
||||||
|
logger = logger.WithField("jwt_user", state.jwtUsername)
|
||||||
|
}
|
||||||
|
return logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to match by SSH connection
|
|
||||||
sshConn := ctx.Value(ssh.ContextKeyConn)
|
sshConn := ctx.Value(ssh.ContextKeyConn)
|
||||||
if sshConn == nil {
|
if sshConn == nil {
|
||||||
return "unknown"
|
return "unknown"
|
||||||
@@ -591,19 +712,14 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
|||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
// Look through sessions to find one with matching connection
|
for sessionKey, state := range s.sessions {
|
||||||
for sessionKey, session := range s.sessions {
|
if state.session.Context().Value(ssh.ContextKeyConn) == sshConn {
|
||||||
if session.Context().Value(ssh.ContextKeyConn) == sshConn {
|
|
||||||
return sessionKey
|
return sessionKey
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no session found, this might be during early connection setup
|
|
||||||
// Return a temporary key that we'll fix up later
|
|
||||||
if ctx.User() != "" && ctx.RemoteAddr() != nil {
|
if ctx.User() != "" && ctx.RemoteAddr() != nil {
|
||||||
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
return sessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
||||||
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
|
|
||||||
return tempKey
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return "unknown"
|
return "unknown"
|
||||||
@@ -644,7 +760,11 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
|
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
|
||||||
return conn
|
return &trackedConn{
|
||||||
|
Conn: conn,
|
||||||
|
server: s,
|
||||||
|
remoteAddr: conn.RemoteAddr().String(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||||
@@ -672,9 +792,8 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
|||||||
"tcpip-forward": s.tcpipForwardHandler,
|
"tcpip-forward": s.tcpipForwardHandler,
|
||||||
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
|
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
|
||||||
},
|
},
|
||||||
ConnCallback: s.connectionValidator,
|
ConnCallback: s.connectionValidator,
|
||||||
ConnectionFailedCallback: s.connectionCloseHandler,
|
Version: serverVersion,
|
||||||
Version: serverVersion,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.jwtEnabled {
|
if s.jwtEnabled {
|
||||||
@@ -690,13 +809,13 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
|||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
|
func (s *Server) storeRemoteForwardListener(key forwardKey, ln net.Listener) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
s.remoteForwardListeners[key] = ln
|
s.remoteForwardListeners[key] = ln
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
func (s *Server) removeRemoteForwardListener(key forwardKey) bool {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
@@ -714,6 +833,8 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
|
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
|
||||||
|
logger := s.getRequestLogger(ctx)
|
||||||
|
|
||||||
var payload struct {
|
var payload struct {
|
||||||
Host string
|
Host string
|
||||||
Port uint32
|
Port uint32
|
||||||
@@ -723,7 +844,7 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
|||||||
|
|
||||||
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
|
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
|
||||||
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
|
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
|
||||||
log.Debugf("channel reject error: %v", err)
|
logger.Debugf("channel reject error: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -733,19 +854,20 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
|||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if !allowLocal {
|
if !allowLocal {
|
||||||
log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port)
|
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
|
||||||
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
|
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check privilege requirements for the destination port
|
|
||||||
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
|
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
|
||||||
log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
|
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
|
||||||
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
|
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
|
||||||
|
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||||
|
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
||||||
|
|
||||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -224,6 +224,96 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_PrivilegedPortAccess(t *testing.T) {
|
||||||
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRemotePortForwarding(true)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
forwardType string
|
||||||
|
port uint32
|
||||||
|
username string
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
skipOnWindows bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non-root user remote forward privileged port",
|
||||||
|
forwardType: "remote",
|
||||||
|
port: 80,
|
||||||
|
username: "testuser",
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "cannot bind to privileged port",
|
||||||
|
skipOnWindows: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-root user tcpip-forward privileged port",
|
||||||
|
forwardType: "tcpip-forward",
|
||||||
|
port: 443,
|
||||||
|
username: "testuser",
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "cannot bind to privileged port",
|
||||||
|
skipOnWindows: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-root user remote forward unprivileged port",
|
||||||
|
forwardType: "remote",
|
||||||
|
port: 8080,
|
||||||
|
username: "testuser",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-root user remote forward port 0",
|
||||||
|
forwardType: "remote",
|
||||||
|
port: 0,
|
||||||
|
username: "testuser",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root user remote forward privileged port",
|
||||||
|
forwardType: "remote",
|
||||||
|
port: 22,
|
||||||
|
username: "root",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "local forward privileged port allowed for non-root",
|
||||||
|
forwardType: "local",
|
||||||
|
port: 80,
|
||||||
|
username: "testuser",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.skipOnWindows && runtime.GOOS == "windows" {
|
||||||
|
t.Skip("Windows does not have privileged port restrictions")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := PrivilegeCheckResult{
|
||||||
|
Allowed: true,
|
||||||
|
User: &user.User{Username: tt.username},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := server.checkPrivilegedPortAccess(tt.forwardType, tt.port, result)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_PortConflictHandling(t *testing.T) {
|
func TestServer_PortConflictHandling(t *testing.T) {
|
||||||
// Test that multiple sessions requesting the same local port are handled naturally by the OS
|
// Test that multiple sessions requesting the same local port are handled naturally by the OS
|
||||||
// Get current user for SSH connection
|
// Get current user for SSH connection
|
||||||
@@ -392,3 +482,95 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_PortForwardingOnlySession(t *testing.T) {
|
||||||
|
// Test that sessions without PTY and command are allowed when port forwarding is enabled
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
require.NoError(t, err, "Should be able to get current user")
|
||||||
|
|
||||||
|
// Generate host key for server
|
||||||
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
allowLocalForwarding bool
|
||||||
|
allowRemoteForwarding bool
|
||||||
|
expectAllowed bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "session_allowed_with_local_forwarding",
|
||||||
|
allowLocalForwarding: true,
|
||||||
|
allowRemoteForwarding: false,
|
||||||
|
expectAllowed: true,
|
||||||
|
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session_allowed_with_remote_forwarding",
|
||||||
|
allowLocalForwarding: false,
|
||||||
|
allowRemoteForwarding: true,
|
||||||
|
expectAllowed: true,
|
||||||
|
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session_allowed_with_both",
|
||||||
|
allowLocalForwarding: true,
|
||||||
|
allowRemoteForwarding: true,
|
||||||
|
expectAllowed: true,
|
||||||
|
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session_denied_without_forwarding",
|
||||||
|
allowLocalForwarding: false,
|
||||||
|
allowRemoteForwarding: false,
|
||||||
|
expectAllowed: false,
|
||||||
|
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
|
||||||
|
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
|
||||||
|
|
||||||
|
serverAddr := StartTestServer(t, server)
|
||||||
|
defer func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Connect to the server without requesting PTY or command
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
client, err := sshclient.Dial(ctx, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = client.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Execute a command without PTY - this simulates ssh -T with no command
|
||||||
|
// The server should either allow it (port forwarding enabled) or reject it
|
||||||
|
output, err := client.ExecuteCommand(ctx, "")
|
||||||
|
if tt.expectAllowed {
|
||||||
|
// When allowed, the session stays open until cancelled
|
||||||
|
// ExecuteCommand with empty command should return without error
|
||||||
|
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
|
||||||
|
assert.NotContains(t, output, "port forwarding is disabled",
|
||||||
|
"Output should not contain port forwarding disabled message")
|
||||||
|
} else if err != nil {
|
||||||
|
// When denied, we expect an error message about port forwarding being disabled
|
||||||
|
assert.Contains(t, err.Error(), "port forwarding is disabled",
|
||||||
|
"Should get port forwarding disabled message")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,37 +6,45 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gliderlabs/ssh"
|
"github.com/gliderlabs/ssh"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// associateJWTUsername extracts pending JWT username for the session and associates it with the session state.
|
||||||
|
// Returns the JWT username (empty if none) for logging purposes.
|
||||||
|
func (s *Server) associateJWTUsername(sess ssh.Session, sessionKey sessionKey) string {
|
||||||
|
key := newAuthKey(sess.User(), sess.RemoteAddr())
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
jwtUsername := s.pendingAuthJWT[key]
|
||||||
|
if jwtUsername == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if state, exists := s.sessions[sessionKey]; exists {
|
||||||
|
state.jwtUsername = jwtUsername
|
||||||
|
}
|
||||||
|
delete(s.pendingAuthJWT, key)
|
||||||
|
return jwtUsername
|
||||||
|
}
|
||||||
|
|
||||||
// sessionHandler handles SSH sessions
|
// sessionHandler handles SSH sessions
|
||||||
func (s *Server) sessionHandler(session ssh.Session) {
|
func (s *Server) sessionHandler(session ssh.Session) {
|
||||||
sessionKey := s.registerSession(session)
|
sessionKey := s.registerSession(session, "")
|
||||||
|
jwtUsername := s.associateJWTUsername(session, sessionKey)
|
||||||
key := newAuthKey(session.User(), session.RemoteAddr())
|
|
||||||
s.mu.Lock()
|
|
||||||
jwtUsername := s.pendingAuthJWT[key]
|
|
||||||
if jwtUsername != "" {
|
|
||||||
s.sessionJWTUsers[sessionKey] = jwtUsername
|
|
||||||
delete(s.pendingAuthJWT, key)
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
logger := log.WithField("session", sessionKey)
|
logger := log.WithField("session", sessionKey)
|
||||||
if jwtUsername != "" {
|
if jwtUsername != "" {
|
||||||
logger = logger.WithField("jwt_user", jwtUsername)
|
logger = logger.WithField("jwt_user", jwtUsername)
|
||||||
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
|
|
||||||
} else {
|
|
||||||
logger.Infof("SSH session started")
|
|
||||||
}
|
}
|
||||||
|
logger.Info("SSH session started")
|
||||||
sessionStart := time.Now()
|
sessionStart := time.Now()
|
||||||
|
|
||||||
defer s.unregisterSession(sessionKey, session)
|
defer s.unregisterSession(sessionKey)
|
||||||
defer func() {
|
defer func() {
|
||||||
duration := time.Since(sessionStart).Round(time.Millisecond)
|
duration := time.Since(sessionStart).Round(time.Millisecond)
|
||||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||||
@@ -65,27 +73,52 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
|||||||
// ssh <host> <cmd> - non-Pty command execution
|
// ssh <host> <cmd> - non-Pty command execution
|
||||||
s.handleCommand(logger, session, privilegeResult, nil)
|
s.handleCommand(logger, session, privilegeResult, nil)
|
||||||
default:
|
default:
|
||||||
s.rejectInvalidSession(logger, session)
|
// ssh -T (or ssh -N) - no PTY, no command
|
||||||
|
s.handleNonInteractiveSession(logger, session)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
|
// handleNonInteractiveSession handles sessions that have no PTY and no command.
|
||||||
if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
|
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
|
||||||
logger.Debugf(errWriteSession, err)
|
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
|
||||||
|
s.updateSessionType(session, cmdNonInteractive)
|
||||||
|
|
||||||
|
if !s.isPortForwardingEnabled() {
|
||||||
|
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
|
||||||
|
logger.Debugf(errWriteSession, err)
|
||||||
|
}
|
||||||
|
if err := session.Exit(1); err != nil {
|
||||||
|
logSessionExitError(logger, err)
|
||||||
|
}
|
||||||
|
logger.Infof("rejected non-interactive session: port forwarding disabled")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if err := session.Exit(1); err != nil {
|
|
||||||
|
<-session.Context().Done()
|
||||||
|
|
||||||
|
if err := session.Exit(0); err != nil {
|
||||||
logSessionExitError(logger, err)
|
logSessionExitError(logger, err)
|
||||||
}
|
}
|
||||||
logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) registerSession(session ssh.Session) SessionKey {
|
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for _, state := range s.sessions {
|
||||||
|
if state.session == session {
|
||||||
|
state.sessionType = sessionType
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) registerSession(session ssh.Session, sessionType string) sessionKey {
|
||||||
sessionID := session.Context().Value(ssh.ContextKeySessionID)
|
sessionID := session.Context().Value(ssh.ContextKeySessionID)
|
||||||
if sessionID == nil {
|
if sessionID == nil {
|
||||||
sessionID = fmt.Sprintf("%p", session)
|
sessionID = fmt.Sprintf("%p", session)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a short 4-byte identifier from the full session ID
|
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
|
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
|
||||||
hash := hasher.Sum(nil)
|
hash := hasher.Sum(nil)
|
||||||
@@ -93,43 +126,23 @@ func (s *Server) registerSession(session ssh.Session) SessionKey {
|
|||||||
|
|
||||||
remoteAddr := session.RemoteAddr().String()
|
remoteAddr := session.RemoteAddr().String()
|
||||||
username := session.User()
|
username := session.User()
|
||||||
sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
|
sessionKey := sessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.sessions[sessionKey] = session
|
s.sessions[sessionKey] = &sessionState{
|
||||||
|
session: session,
|
||||||
|
sessionType: sessionType,
|
||||||
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
return sessionKey
|
return sessionKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
|
func (s *Server) unregisterSession(sessionKey sessionKey) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
delete(s.sessions, sessionKey)
|
delete(s.sessions, sessionKey)
|
||||||
delete(s.sessionJWTUsers, sessionKey)
|
|
||||||
|
|
||||||
// Cancel all port forwarding connections for this session
|
|
||||||
var connectionsToCancel []ConnectionKey
|
|
||||||
for key := range s.sessionCancels {
|
|
||||||
if strings.HasPrefix(string(key), string(sessionKey)+"-") {
|
|
||||||
connectionsToCancel = append(connectionsToCancel, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, key := range connectionsToCancel {
|
|
||||||
if cancelFunc, exists := s.sessionCancels[key]; exists {
|
|
||||||
log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
|
|
||||||
cancelFunc()
|
|
||||||
delete(s.sessionCancels, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
|
|
||||||
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
|
|
||||||
delete(s.sshConnections, sshConn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {
|
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {
|
||||||
|
|||||||
@@ -18,14 +18,26 @@ func (s *Server) SetAllowSFTP(allow bool) {
|
|||||||
|
|
||||||
// sftpSubsystemHandler handles SFTP subsystem requests
|
// sftpSubsystemHandler handles SFTP subsystem requests
|
||||||
func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
|
func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
|
||||||
|
sessionKey := s.registerSession(sess, cmdSFTP)
|
||||||
|
defer s.unregisterSession(sessionKey)
|
||||||
|
|
||||||
|
jwtUsername := s.associateJWTUsername(sess, sessionKey)
|
||||||
|
|
||||||
|
logger := log.WithField("session", sessionKey)
|
||||||
|
if jwtUsername != "" {
|
||||||
|
logger = logger.WithField("jwt_user", jwtUsername)
|
||||||
|
}
|
||||||
|
logger.Info("SFTP session started")
|
||||||
|
defer logger.Info("SFTP session closed")
|
||||||
|
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
allowSFTP := s.allowSFTP
|
allowSFTP := s.allowSFTP
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if !allowSFTP {
|
if !allowSFTP {
|
||||||
log.Debugf("SFTP subsystem request denied: SFTP disabled")
|
logger.Debug("SFTP subsystem request denied: SFTP disabled")
|
||||||
if err := sess.Exit(1); err != nil {
|
if err := sess.Exit(1); err != nil {
|
||||||
log.Debugf("SFTP session exit failed: %v", err)
|
logger.Debugf("SFTP session exit: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -37,31 +49,27 @@ func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if !result.Allowed {
|
if !result.Allowed {
|
||||||
log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error)
|
logger.Warnf("SFTP access denied: %v", result.Error)
|
||||||
if err := sess.Exit(1); err != nil {
|
if err := sess.Exit(1); err != nil {
|
||||||
log.Debugf("exit SFTP session: %v", err)
|
logger.Debugf("exit SFTP session: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username)
|
|
||||||
|
|
||||||
if !result.RequiresUserSwitching {
|
if !result.RequiresUserSwitching {
|
||||||
if err := s.executeSftpDirect(sess); err != nil {
|
if err := s.executeSftpDirect(sess); err != nil {
|
||||||
log.Errorf("SFTP direct execution: %v", err)
|
logger.Errorf("SFTP direct execution: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
|
if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
|
||||||
log.Errorf("SFTP privilege drop execution: %v", err)
|
logger.Errorf("SFTP privilege drop execution: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeSftpDirect executes SFTP directly without privilege dropping
|
// executeSftpDirect executes SFTP directly without privilege dropping
|
||||||
func (s *Server) executeSftpDirect(sess ssh.Session) error {
|
func (s *Server) executeSftpDirect(sess ssh.Session) error {
|
||||||
log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User())
|
|
||||||
|
|
||||||
sftpServer, err := sftp.NewServer(sess)
|
sftpServer, err := sftp.NewServer(sess)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("SFTP server creation: %w", err)
|
return fmt.Errorf("SFTP server creation: %w", err)
|
||||||
|
|||||||
@@ -82,10 +82,11 @@ type NsServerGroupStateOutput struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SSHSessionOutput struct {
|
type SSHSessionOutput struct {
|
||||||
Username string `json:"username" yaml:"username"`
|
Username string `json:"username" yaml:"username"`
|
||||||
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
|
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
|
||||||
Command string `json:"command" yaml:"command"`
|
Command string `json:"command" yaml:"command"`
|
||||||
JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
|
JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
|
||||||
|
PortForwards []string `json:"portForwards,omitempty" yaml:"portForwards,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SSHServerStateOutput struct {
|
type SSHServerStateOutput struct {
|
||||||
@@ -220,6 +221,7 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
|
|||||||
RemoteAddress: session.GetRemoteAddress(),
|
RemoteAddress: session.GetRemoteAddress(),
|
||||||
Command: session.GetCommand(),
|
Command: session.GetCommand(),
|
||||||
JWTUsername: session.GetJwtUsername(),
|
JWTUsername: session.GetJwtUsername(),
|
||||||
|
PortForwards: session.GetPortForwards(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -475,6 +477,9 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
sshServerStatus += "\n " + sessionDisplay
|
sshServerStatus += "\n " + sessionDisplay
|
||||||
|
for _, pf := range session.PortForwards {
|
||||||
|
sshServerStatus += "\n " + pf
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user