mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-12 19:09:54 +00:00
Compare commits
3 Commits
embedded-v
...
socket-grp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d65927275d | ||
|
|
064f7bf0fd | ||
|
|
644615fed6 |
36
client/cmd/peercred_bsd.go
Normal file
36
client/cmd/peercred_bsd.go
Normal file
@@ -0,0 +1,36 @@
|
||||
//go:build darwin || freebsd
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// peerUID returns the uid of the process on the other end of a unix socket
|
||||
// connection, read via LOCAL_PEERCRED (xucred). Note: xucred carries the uid
|
||||
// and group list but no pid, so audit on these platforms is uid-based.
|
||||
func peerUID(c net.Conn) (int, error) {
|
||||
uc, ok := c.(*net.UnixConn)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
|
||||
}
|
||||
raw, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("raw conn: %w", err)
|
||||
}
|
||||
|
||||
var cred *unix.Xucred
|
||||
var credErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
cred, credErr = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
|
||||
}); err != nil {
|
||||
return 0, fmt.Errorf("getsockopt control: %w", err)
|
||||
}
|
||||
if credErr != nil {
|
||||
return 0, fmt.Errorf("LOCAL_PEERCRED: %w", credErr)
|
||||
}
|
||||
return int(cred.Uid), nil
|
||||
}
|
||||
35
client/cmd/peercred_linux.go
Normal file
35
client/cmd/peercred_linux.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build linux
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// peerUID returns the uid of the process on the other end of a unix socket
|
||||
// connection, read from the kernel via SO_PEERCRED.
|
||||
func peerUID(c net.Conn) (int, error) {
|
||||
uc, ok := c.(*net.UnixConn)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
|
||||
}
|
||||
raw, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("raw conn: %w", err)
|
||||
}
|
||||
|
||||
var cred *unix.Ucred
|
||||
var credErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
|
||||
}); err != nil {
|
||||
return 0, fmt.Errorf("getsockopt control: %w", err)
|
||||
}
|
||||
if credErr != nil {
|
||||
return 0, fmt.Errorf("SO_PEERCRED: %w", credErr)
|
||||
}
|
||||
return int(cred.Uid), nil
|
||||
}
|
||||
16
client/cmd/peercred_unsupported.go
Normal file
16
client/cmd/peercred_unsupported.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !linux && !darwin && !freebsd
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// peerUID is unimplemented on this platform, so the trust-on-first-use socket
|
||||
// migration cannot run here. Configure --socket-owner explicitly, or use
|
||||
// --disable-strict-socket. (Windows uses a TCP socket and never reaches this.)
|
||||
func peerUID(net.Conn) (int, error) {
|
||||
return 0, fmt.Errorf("peer credential check not supported on %s", runtime.GOOS)
|
||||
}
|
||||
@@ -77,6 +77,8 @@ var (
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
socketOwner string
|
||||
strictSocketDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
|
||||
@@ -57,6 +57,9 @@ func init() {
|
||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
|
||||
serviceCmd.PersistentFlags().StringVar(&socketOwner, "socket-owner", "", "user to own the daemon control socket; restricts it to that user plus the netbird group (0660). If unset, the first client to connect claims ownership (trust-on-first-use)")
|
||||
serviceCmd.PersistentFlags().BoolVar(&strictSocketDisabled, "disable-strict-socket", false, "leave the daemon control socket world-writable (0666) instead of restricting it; set via the (root-only) service command")
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,10 +4,15 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -16,6 +21,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -54,10 +60,36 @@ func (p *program) Start(svc service.Service) error {
|
||||
go func() {
|
||||
defer listen.Close()
|
||||
|
||||
srvListener := listen
|
||||
if split[0] == "unix" {
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
owner := effectiveSocketOwner()
|
||||
switch {
|
||||
case strictSocketDisabled:
|
||||
// Opt-out (root-only, via service.json): leave it world-writable.
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
case owner != "":
|
||||
// Seeded owner (flag, MDM, or persisted TOFU result): restrict
|
||||
// before serving so there is no open window.
|
||||
uid, err := lookupUser(owner)
|
||||
if err != nil {
|
||||
log.Errorf("lookup socket owner %q: %v", owner, err)
|
||||
return
|
||||
}
|
||||
if err := restrictSocket(split[1], uid); err != nil {
|
||||
log.Errorf("restrict socket to %q: %v", owner, err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
// Trust-on-first-use: open the socket now; tofuListener locks it
|
||||
// to the first caller's uid on the first connection.
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
srvListener = &tofuListener{Listener: listen, path: split[1], owner: -1}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,13 +104,180 @@ func (p *program) Start(svc service.Service) error {
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
log.Printf("started daemon server: %v", split[1])
|
||||
if err := p.serv.Serve(listen); err != nil {
|
||||
if err := p.serv.Serve(srvListener); err != nil {
|
||||
log.Errorf("failed to serve daemon requests: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupUser(username string) (int, error) {
|
||||
u, err := shell.LookupWithGetent(username)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
uid, err := strconv.Atoi(u.Uid)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("parse uid %s: %w", u.Uid, err)
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// addGroup creates a system group if it doesn't already exist and returns the gid.
|
||||
// Must run as root.
|
||||
func addGroup(name string) (int, error) {
|
||||
group, err := shell.LookupGroupWithGetent(name)
|
||||
if err == nil {
|
||||
gid, err := strconv.ParseInt(group.Gid, 10, 64)
|
||||
return int(gid), err
|
||||
}
|
||||
|
||||
// looup failed, create the group
|
||||
groupadd, err := exec.LookPath("groupadd")
|
||||
if err != nil {
|
||||
// Fallback for Alpine/BusyBox systems.
|
||||
if groupadd, err = exec.LookPath("addgroup"); err != nil {
|
||||
return -1, errors.New("neither groupadd nor addgroup found")
|
||||
}
|
||||
}
|
||||
|
||||
// Use --system for a service/daemon group (no login, low GID).
|
||||
out, err := exec.Command(groupadd, "--system", name).CombinedOutput()
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("create group %q: %w: %s", name, err, out)
|
||||
}
|
||||
if group, err := shell.LookupWithGetent(name); err == nil {
|
||||
gid, err := strconv.ParseInt(group.Gid, 10, 64)
|
||||
return int(gid), err
|
||||
}
|
||||
return -1, fmt.Errorf("lookup group %q: %w", name, err)
|
||||
}
|
||||
|
||||
// restrictSocket locks the unix socket down to the owner uid plus the netbird
|
||||
// group (0660). If the group cannot be created or applied, it fails closed to
|
||||
// owner-only 0600 — it never leaves the socket world-writable.
|
||||
func restrictSocket(path string, uid int) error {
|
||||
// TODO: introduce flag to configure this (LDAP/AD usecase)
|
||||
gid, err := addGroup("netbird")
|
||||
if err != nil {
|
||||
log.Errorf("create netbird group, failing closed to owner-only 0600: %v", err)
|
||||
return chownChmod(path, uid, -1, 0600)
|
||||
}
|
||||
if err := chownChmod(path, uid, gid, 0660); err != nil {
|
||||
log.Errorf("apply netbird group to socket, failing closed to owner-only 0600: %v", err)
|
||||
return chownChmod(path, uid, -1, 0600)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// chownChmod sets ownership and mode on the socket. A gid of -1 leaves the
|
||||
// group unchanged.
|
||||
func chownChmod(path string, uid, gid int, mode os.FileMode) error {
|
||||
if err := os.Chown(path, uid, gid); err != nil {
|
||||
return fmt.Errorf("chown socket %s: %w", path, err)
|
||||
}
|
||||
if err := os.Chmod(path, mode); err != nil {
|
||||
return fmt.Errorf("chmod socket %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tofuListener implements trust-on-first-use for the daemon control socket.
|
||||
// The socket starts world-writable; the first caller's uid (read via the
|
||||
// platform peer-credential mechanism) becomes the owner. On that first
|
||||
// connection the socket is restricted (see restrictSocket) and the owner is
|
||||
// persisted so the open window never reopens on later starts. Connections that
|
||||
// raced in during the open window and are neither the owner nor root are
|
||||
// dropped. Changing the socket mode does not disturb the already-open
|
||||
// connection, so the first caller's request is served normally.
|
||||
type tofuListener struct {
|
||||
net.Listener
|
||||
path string
|
||||
mu sync.Mutex
|
||||
owner int // -1 until claimed
|
||||
}
|
||||
|
||||
func (l *tofuListener) Accept() (net.Conn, error) {
|
||||
for {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uid, err := peerUID(c)
|
||||
if err != nil {
|
||||
log.Errorf("read peer credentials, dropping connection: %v", err)
|
||||
_ = c.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
if l.owner == -1 {
|
||||
if err := restrictSocket(l.path, uid); err != nil {
|
||||
l.mu.Unlock()
|
||||
_ = c.Close()
|
||||
// Refuse to serve on a socket we could not lock down.
|
||||
return nil, fmt.Errorf("restrict socket on first connection: %w", err)
|
||||
}
|
||||
l.owner = uid
|
||||
persistSocketOwner(uid)
|
||||
log.Infof("control socket restricted to first caller (uid %d)", uid)
|
||||
l.mu.Unlock()
|
||||
return c, nil
|
||||
}
|
||||
owner := l.owner
|
||||
l.mu.Unlock()
|
||||
|
||||
// New connects are already gated by the 0660 perms set above; this only
|
||||
// drops anything that slipped in during the brief open window.
|
||||
if uid != owner && uid != 0 {
|
||||
log.Warnf("dropping non-owner connection (uid %d) during socket bootstrap", uid)
|
||||
_ = c.Close()
|
||||
continue
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
// effectiveSocketOwner returns the configured socket owner: the --socket-owner
|
||||
// flag when set, otherwise the owner persisted by a previous TOFU migration.
|
||||
func effectiveSocketOwner() string {
|
||||
if socketOwner != "" {
|
||||
return socketOwner
|
||||
}
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
log.Errorf("load service params for socket owner: %v", err)
|
||||
return ""
|
||||
}
|
||||
if params != nil {
|
||||
return params.SocketOwner
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// persistSocketOwner records the TOFU-selected owner (by username) so the next
|
||||
// daemon start restricts the socket immediately, with no open window.
|
||||
func persistSocketOwner(uid int) {
|
||||
u, err := user.LookupId(strconv.Itoa(uid))
|
||||
if err != nil {
|
||||
log.Errorf("resolve uid %d to username for persistence: %v", uid, err)
|
||||
return
|
||||
}
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
log.Errorf("load service params to persist socket owner: %v", err)
|
||||
return
|
||||
}
|
||||
if params == nil {
|
||||
params = currentServiceParams()
|
||||
}
|
||||
params.SocketOwner = u.Username
|
||||
if err := saveServiceParams(params); err != nil {
|
||||
log.Errorf("persist socket owner: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
p.serverInstanceMu.Lock()
|
||||
if p.serverInstance != nil {
|
||||
|
||||
@@ -67,6 +67,14 @@ func buildServiceArguments() []string {
|
||||
args = append(args, "--disable-networks")
|
||||
}
|
||||
|
||||
if socketOwner != "" {
|
||||
args = append(args, "--socket-owner", socketOwner)
|
||||
}
|
||||
|
||||
if strictSocketDisabled {
|
||||
args = append(args, "--disable-strict-socket")
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
@@ -127,6 +135,8 @@ var installCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("SUDO_UID: %s\n", os.Getenv("SUDO_UID"))
|
||||
|
||||
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ type serviceParams struct {
|
||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||
EnableCapture bool `json:"enable_capture,omitempty"`
|
||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||
SocketOwner string `json:"socket_owner,omitempty"`
|
||||
DisableStrictSocket bool `json:"disable_strict_socket,omitempty"`
|
||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||
}
|
||||
|
||||
@@ -82,6 +84,8 @@ func currentServiceParams() *serviceParams {
|
||||
DisableUpdateSettings: updateSettingsDisabled,
|
||||
EnableCapture: captureEnabled,
|
||||
DisableNetworks: networksDisabled,
|
||||
SocketOwner: socketOwner,
|
||||
DisableStrictSocket: strictSocketDisabled,
|
||||
}
|
||||
|
||||
if len(serviceEnvVars) > 0 {
|
||||
@@ -154,6 +158,14 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||
networksDisabled = params.DisableNetworks
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("socket-owner") {
|
||||
socketOwner = params.SocketOwner
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-strict-socket") {
|
||||
strictSocketDisabled = params.DisableStrictSocket
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, params)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build cgo && !osusergo && !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import "os/user"
|
||||
|
||||
@@ -8,17 +8,22 @@ import "os/user"
|
||||
// When CGO is enabled, os/user uses libc (getpwnam_r) which goes through
|
||||
// the NSS stack natively. If it fails, the user truly doesn't exist and
|
||||
// getent would also fail.
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
return user.Lookup(username)
|
||||
}
|
||||
|
||||
// currentUserWithGetent with CGO delegates directly to os/user.Current.
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
|
||||
func LookupGroupWithGetent(name string) (*user.Group, error) {
|
||||
return user.LookupGroup(name)
|
||||
}
|
||||
|
||||
// groupIdsWithFallback with CGO delegates directly to user.GroupIds.
|
||||
// libc's getgrouplist handles NSS groups natively.
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
return u.GroupIds()
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build (!cgo || osusergo) && !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
// lookupWithGetent looks up a user by name, falling back to getent if os/user fails.
|
||||
// Without CGO, os/user only reads /etc/passwd and misses NSS-provided users.
|
||||
// getent goes through the host's NSS stack.
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
u, err := user.Lookup(username)
|
||||
if err == nil {
|
||||
return u, nil
|
||||
@@ -22,7 +22,7 @@ func lookupWithGetent(username string) (*user.User, error) {
|
||||
stdErr := err
|
||||
log.Debugf("os/user.Lookup(%q) failed, trying getent: %v", username, err)
|
||||
|
||||
u, _, getentErr := runGetent(username)
|
||||
u, _, getentErr := runGetentPasswd(username)
|
||||
if getentErr != nil {
|
||||
log.Debugf("getent fallback for %q also failed: %v", username, getentErr)
|
||||
return nil, stdErr
|
||||
@@ -31,8 +31,25 @@ func lookupWithGetent(username string) (*user.User, error) {
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
|
||||
func LookupGroupWithGetent(name string) (*user.Group, error) {
|
||||
g, err := user.LookupGroup(name)
|
||||
if err == nil {
|
||||
return g, nil
|
||||
}
|
||||
|
||||
stdErr := err
|
||||
log.Debugf("os/user.LookupGroup(%q) failed, trying getent: %v", name, err)
|
||||
g, getentErr := runGetentGroup(name)
|
||||
if getentErr != nil {
|
||||
log.Debugf("getent fallback for %q also failed: %v", name, getentErr)
|
||||
return nil, stdErr
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// currentUserWithGetent gets the current user, falling back to getent if os/user fails.
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
u, err := user.Current()
|
||||
if err == nil {
|
||||
return u, nil
|
||||
@@ -42,7 +59,7 @@ func currentUserWithGetent() (*user.User, error) {
|
||||
uid := strconv.Itoa(os.Getuid())
|
||||
log.Debugf("os/user.Current() failed, trying getent with UID %s: %v", uid, err)
|
||||
|
||||
u, _, getentErr := runGetent(uid)
|
||||
u, _, getentErr := runGetentPasswd(uid)
|
||||
if getentErr != nil {
|
||||
return nil, stdErr
|
||||
}
|
||||
@@ -57,7 +74,7 @@ func currentUserWithGetent() (*user.User, error) {
|
||||
// only reads /etc/group and silently returns incomplete results for NSS users
|
||||
// (no error, just missing groups). The id command goes through NSS and returns
|
||||
// the full set.
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
ids, err := runIdGroups(u.Username)
|
||||
if err == nil {
|
||||
return ids, nil
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
@@ -15,7 +15,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := lookupWithGetent(current.Username)
|
||||
u, err := LookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, current.Username, u.Username)
|
||||
assert.Equal(t, current.Uid, u.Uid)
|
||||
@@ -23,7 +23,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLookupWithGetent_NonexistentUser(t *testing.T) {
|
||||
_, err := lookupWithGetent("nonexistent_user_xyzzy_12345")
|
||||
_, err := LookupWithGetent("nonexistent_user_xyzzy_12345")
|
||||
require.Error(t, err, "should fail for nonexistent user")
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestCurrentUserWithGetent(t *testing.T) {
|
||||
stdUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := currentUserWithGetent()
|
||||
u, err := CurrentUserWithGetent()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, stdUser.Uid, u.Uid)
|
||||
assert.Equal(t, stdUser.Username, u.Username)
|
||||
@@ -41,7 +41,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err := groupIdsWithFallback(current)
|
||||
groups, err := GroupIdsWithFallback(current)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "current user should have at least one group")
|
||||
|
||||
@@ -56,7 +56,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
|
||||
func TestGetShellFromGetent_CurrentUser(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
// Windows stub always returns empty, which is correct
|
||||
shell := getShellFromGetent("1000")
|
||||
shell := GetShellFromGetent("1000")
|
||||
assert.Empty(t, shell, "Windows stub should return empty")
|
||||
return
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func TestGetShellFromGetent_CurrentUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// getent may not be available on all systems (e.g., macOS without Homebrew getent)
|
||||
shell := getShellFromGetent(current.Uid)
|
||||
shell := GetShellFromGetent(current.Uid)
|
||||
if shell == "" {
|
||||
t.Log("getShellFromGetent returned empty, getent may not be available")
|
||||
return
|
||||
@@ -78,7 +78,7 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
|
||||
t.Skip("no root user on Windows")
|
||||
}
|
||||
|
||||
u, err := lookupWithGetent("root")
|
||||
u, err := LookupWithGetent("root")
|
||||
if err != nil {
|
||||
t.Skip("root user not available on this system")
|
||||
}
|
||||
@@ -91,20 +91,20 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
|
||||
// consistent and correct results when composed together.
|
||||
func TestIntegration_FullLookupChain(t *testing.T) {
|
||||
// Step 1: currentUserWithGetent must resolve the running user.
|
||||
current, err := currentUserWithGetent()
|
||||
current, err := CurrentUserWithGetent()
|
||||
require.NoError(t, err, "currentUserWithGetent must resolve the running user")
|
||||
require.NotEmpty(t, current.Uid)
|
||||
require.NotEmpty(t, current.Username)
|
||||
|
||||
// Step 2: lookupWithGetent by the same username must return matching identity.
|
||||
byName, err := lookupWithGetent(current.Username)
|
||||
byName, err := LookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, current.Uid, byName.Uid, "lookup by name should return same UID")
|
||||
assert.Equal(t, current.Gid, byName.Gid, "lookup by name should return same GID")
|
||||
assert.Equal(t, current.HomeDir, byName.HomeDir, "lookup by name should return same home")
|
||||
|
||||
// Step 3: groupIdsWithFallback must return at least the primary GID.
|
||||
groups, err := groupIdsWithFallback(current)
|
||||
groups, err := GroupIdsWithFallback(current)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "user must have at least one group")
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestIntegration_FullLookupChain(t *testing.T) {
|
||||
// Step 4: getShellFromGetent should either return a valid shell path or empty
|
||||
// (empty is OK when getent is not available, e.g. macOS without Homebrew getent).
|
||||
if runtime.GOOS != "windows" {
|
||||
shell := getShellFromGetent(current.Uid)
|
||||
shell := GetShellFromGetent(current.Uid)
|
||||
if shell != "" {
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -138,10 +138,10 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the SSH server flow: lookup user, then get their groups.
|
||||
resolved, err := lookupWithGetent(current.Username)
|
||||
resolved, err := LookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err := groupIdsWithFallback(resolved)
|
||||
groups, err := GroupIdsWithFallback(resolved)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "resolved user must have groups")
|
||||
|
||||
@@ -154,19 +154,3 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_ShellLookupChain tests the full shell resolution chain
|
||||
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
|
||||
func TestIntegration_ShellLookupChain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix shell lookup not applicable on Windows")
|
||||
}
|
||||
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// getUserShell is the top-level function used by the SSH server.
|
||||
shell := getUserShell(current.Uid)
|
||||
require.NotEmpty(t, shell, "getUserShell must always return a shell")
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,19 +14,25 @@ import (
|
||||
|
||||
const getentTimeout = 5 * time.Second
|
||||
|
||||
// getShellFromGetent gets a user's login shell via getent by UID.
|
||||
// GetShellFromGetent gets a user's login shell via getent by UID.
|
||||
// This is needed even with CGO because getShellFromPasswd reads /etc/passwd
|
||||
// directly and won't find NSS-provided users there.
|
||||
func getShellFromGetent(userID string) string {
|
||||
_, shell, err := runGetent(userID)
|
||||
func GetShellFromGetent(userID string) string {
|
||||
_, shell, err := runGetentPasswd(userID)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return shell
|
||||
}
|
||||
|
||||
// runGetent executes `getent passwd <query>` and returns the user and login shell.
|
||||
func runGetent(query string) (*user.User, string, error) {
|
||||
// GetUserFromGetent returns the resolved group from either a uid or username
|
||||
func GetUserFromGetent(user string) (*user.User, error) {
|
||||
u, _, err := runGetentPasswd(user)
|
||||
return u, err
|
||||
}
|
||||
|
||||
// runGetentPasswd executes `getent passwd <query>` and returns the user and login shell.
|
||||
func runGetentPasswd(query string) (*user.User, string, error) {
|
||||
if !validateGetentInput(query) {
|
||||
return nil, "", fmt.Errorf("invalid getent input: %q", query)
|
||||
}
|
||||
@@ -42,6 +48,23 @@ func runGetent(query string) (*user.User, string, error) {
|
||||
return parseGetentPasswd(string(out))
|
||||
}
|
||||
|
||||
// runGetentGroup executes `getent group <query>` and returns the group
|
||||
func runGetentGroup(query string) (*user.Group, error) {
|
||||
if !validateGetentInput(query) {
|
||||
return nil, fmt.Errorf("invalid getent input: %q", query)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), getentTimeout)
|
||||
defer cancel()
|
||||
|
||||
out, err := exec.CommandContext(ctx, "getent", "group", query).Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getent passwd%s: %w", query, err)
|
||||
}
|
||||
|
||||
return parseGetentGroup(string(out))
|
||||
}
|
||||
|
||||
// parseGetentPasswd parses getent passwd output: "name:x:uid:gid:gecos:home:shell"
|
||||
func parseGetentPasswd(output string) (*user.User, string, error) {
|
||||
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
|
||||
@@ -67,6 +90,20 @@ func parseGetentPasswd(output string) (*user.User, string, error) {
|
||||
}, shell, nil
|
||||
}
|
||||
|
||||
// parseGetentGroup parses getent group output: "group:x:gid:user"
|
||||
func parseGetentGroup(output string) (*user.Group, error) {
|
||||
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
|
||||
if len(fields) < 4 {
|
||||
return nil, fmt.Errorf("unexpected getent output (need 4+ fields): %q", output)
|
||||
}
|
||||
|
||||
if fields[0] == "" || fields[2] == "" {
|
||||
return nil, fmt.Errorf("missing required fields in getent output: %q", output)
|
||||
}
|
||||
|
||||
return &user.Group{Gid: fields[2], Name: fields[0]}, nil
|
||||
}
|
||||
|
||||
// validateGetentInput checks that the input is safe to pass to getent or id.
|
||||
// Allows POSIX usernames, numeric UIDs, and common NSS extensions
|
||||
// (@ for Kerberos, $ for Samba, + for NIS compat).
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
@@ -195,7 +195,7 @@ func TestRunGetent_RootUser(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
u, shell, err := runGetent("root")
|
||||
u, shell, err := runGetentPasswd("root")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "root", u.Username)
|
||||
assert.Equal(t, "0", u.Uid)
|
||||
@@ -208,7 +208,7 @@ func TestRunGetent_ByUID(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
u, _, err := runGetent("0")
|
||||
u, _, err := runGetentPasswd("0")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "root", u.Username)
|
||||
assert.Equal(t, "0", u.Uid)
|
||||
@@ -219,15 +219,15 @@ func TestRunGetent_NonexistentUser(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
_, _, err := runGetent("nonexistent_user_xyzzy_12345")
|
||||
_, _, err := runGetentPasswd("nonexistent_user_xyzzy_12345")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRunGetent_InvalidInput(t *testing.T) {
|
||||
_, _, err := runGetent("")
|
||||
_, _, err := runGetentPasswd("")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, _, err = runGetent("user\x00name")
|
||||
_, _, err = runGetentPasswd("user\x00name")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -236,7 +236,7 @@ func TestRunGetent_NotAvailable(t *testing.T) {
|
||||
t.Skip("getent is available, can't test missing case")
|
||||
}
|
||||
|
||||
_, _, err := runGetent("root")
|
||||
_, _, err := runGetentPasswd("root")
|
||||
assert.Error(t, err, "should fail when getent is not installed")
|
||||
}
|
||||
|
||||
@@ -283,7 +283,7 @@ func TestGetentResultsMatchStdlib(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
getentUser, _, err := runGetent(current.Username)
|
||||
getentUser, _, err := runGetentPasswd(current.Username)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, current.Username, getentUser.Username, "username should match")
|
||||
@@ -300,7 +300,7 @@ func TestGetentResultsMatchStdlib_ByUID(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
getentUser, _, err := runGetent(current.Uid)
|
||||
getentUser, _, err := runGetentPasswd(current.Uid)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, current.Username, getentUser.Username, "username should match when looked up by UID")
|
||||
@@ -356,7 +356,7 @@ func TestGetShellFromPasswd_CurrentUser(t *testing.T) {
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
|
||||
if _, err := exec.LookPath("getent"); err == nil {
|
||||
_, getentShell, getentErr := runGetent(current.Uid)
|
||||
_, getentShell, getentErr := runGetentPasswd(current.Uid)
|
||||
if getentErr == nil && getentShell != "" {
|
||||
assert.Equal(t, getentShell, shell, "shell from /etc/passwd should match getent")
|
||||
}
|
||||
@@ -400,7 +400,7 @@ func TestGetShellFromPasswd_MatchesGetentForKnownUsers(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
|
||||
_, getentShell, err := runGetent(uid)
|
||||
_, getentShell, err := runGetentPasswd(uid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -1,26 +1,26 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import "os/user"
|
||||
|
||||
// lookupWithGetent on Windows just delegates to os/user.Lookup.
|
||||
// Windows does not use NSS/getent; its user lookup works without CGO.
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
return user.Lookup(username)
|
||||
}
|
||||
|
||||
// currentUserWithGetent on Windows just delegates to os/user.Current.
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
// getShellFromGetent is a no-op on Windows; shell resolution uses PowerShell detection.
|
||||
func getShellFromGetent(_ string) string {
|
||||
func GetShellFromGetent(_ string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// groupIdsWithFallback on Windows just delegates to u.GroupIds().
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
return u.GroupIds()
|
||||
}
|
||||
@@ -1,17 +1,14 @@
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -24,7 +21,7 @@ const (
|
||||
|
||||
// getUserShell returns the appropriate shell for the given user ID
|
||||
// Handles all platform-specific logic and fallbacks consistently
|
||||
func getUserShell(userID string) string {
|
||||
func GetUserShell(userID string) string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return getWindowsUserShell()
|
||||
@@ -56,7 +53,7 @@ func getUnixUserShell(userID string) string {
|
||||
return shell
|
||||
}
|
||||
|
||||
if shell := getShellFromGetent(userID); shell != "" {
|
||||
if shell := GetShellFromGetent(userID); shell != "" {
|
||||
return shell
|
||||
}
|
||||
|
||||
@@ -101,8 +98,8 @@ func getShellFromPasswd(userID string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// prepareUserEnv prepares environment variables for user execution
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
// PrepareUserEnv prepares environment variables for user execution
|
||||
func PrepareUserEnv(user *user.User, shell string) []string {
|
||||
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
|
||||
if runtime.GOOS == "windows" {
|
||||
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
|
||||
@@ -119,7 +116,7 @@ func prepareUserEnv(user *user.User, shell string) []string {
|
||||
|
||||
// acceptEnv checks if environment variable from SSH client should be accepted
|
||||
// This is a whitelist of variables that SSH clients can send to the server
|
||||
func acceptEnv(envVar string) bool {
|
||||
func AcceptEnv(envVar string) bool {
|
||||
varName := envVar
|
||||
if idx := strings.Index(envVar, "="); idx != -1 {
|
||||
varName = envVar[:idx]
|
||||
@@ -156,29 +153,3 @@ func acceptEnv(envVar string) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
26
client/internal/shell/shell_test.go
Normal file
26
client/internal/shell/shell_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIntegration_ShellLookupChain tests the full shell resolution chain
|
||||
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
|
||||
func TestIntegration_ShellLookupChain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix shell lookup not applicable on Windows")
|
||||
}
|
||||
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// getUserShell is the top-level function used by the SSH server.
|
||||
shell := GetUserShell(current.Uid)
|
||||
require.NotEmpty(t, shell, "getUserShell must always return a shell")
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -146,10 +147,10 @@ func (s *Server) createShellCommand(ctx context.Context, shell string, args []st
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,10 +247,10 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
|
||||
userEnv, err := s.getUserEnvironment(logger, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
@@ -260,7 +260,7 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
|
||||
env := userEnv
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
@@ -273,7 +273,7 @@ func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privileg
|
||||
return false
|
||||
}
|
||||
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
shell := shell.GetUserShell(privilegeResult.User.Uid)
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
|
||||
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
|
||||
@@ -384,7 +384,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _
|
||||
}
|
||||
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
shell := getUserShell(localUser.Uid)
|
||||
shell := shell.GetUserShell(localUser.Uid)
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
|
||||
@@ -3,11 +3,15 @@ package server
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -23,8 +27,8 @@ func isPlatformUnix() bool {
|
||||
|
||||
// Dependency injection variables for testing - allows mocking dynamic runtime checks
|
||||
var (
|
||||
getCurrentUser = currentUserWithGetent
|
||||
lookupUser = lookupWithGetent
|
||||
getCurrentUser = shell.CurrentUserWithGetent
|
||||
lookupUser = shell.LookupWithGetent
|
||||
getCurrentOS = func() string { return runtime.GOOS }
|
||||
getIsProcessPrivileged = isCurrentProcessPrivileged
|
||||
|
||||
@@ -409,3 +413,29 @@ func isWindowsElevated() bool {
|
||||
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -160,7 +161,7 @@ func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []u
|
||||
// getSupplementaryGroups retrieves supplementary group IDs for a user.
|
||||
// Uses id/getent fallback for NSS users in CGO_ENABLED=0 builds.
|
||||
func (s *Server) getSupplementaryGroups(u *user.User) ([]uint32, error) {
|
||||
groupIDStrings, err := groupIdsWithFallback(u)
|
||||
groupIDStrings, err := shell.GroupIdsWithFallback(u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group IDs for user %s: %w", u.Username, err)
|
||||
}
|
||||
@@ -196,7 +197,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
|
||||
GID: gid,
|
||||
Groups: groups,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: getUserShell(localUser.Uid),
|
||||
Shell: shell.GetUserShell(localUser.Uid),
|
||||
Command: session.RawCommand(),
|
||||
PTY: hasPty,
|
||||
}
|
||||
@@ -228,7 +229,7 @@ func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq s
|
||||
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
|
||||
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
shell := shell.GetUserShell(localUser.Uid)
|
||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||
|
||||
cmd := s.createShellCommand(session.Context(), shell, args)
|
||||
@@ -245,12 +246,12 @@ func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh
|
||||
termType = "xterm-256color"
|
||||
}
|
||||
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
env = append(env, fmt.Sprintf("TERM=%s", termType))
|
||||
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
)
|
||||
|
||||
// validateUsername validates Windows usernames according to SAM Account Name rules
|
||||
@@ -104,7 +106,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
|
||||
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
sh := shell.GetUserShell(localUser.Uid)
|
||||
|
||||
rawCmd := session.RawCommand()
|
||||
var command string
|
||||
@@ -116,7 +118,7 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Shell: sh,
|
||||
Command: command,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user