Compare commits

...

3 Commits

Author SHA1 Message Date
Theodor S. Midtlien
d65927275d Refactor shell package and use getent for user/group lookup 2026-06-11 17:40:45 +02:00
Theodor S. Midtlien
064f7bf0fd WIP TOFU socket ownership 2026-06-10 17:40:17 +02:00
Theodor S. Midtlien
644615fed6 WIP test 2026-06-10 17:29:00 +02:00
21 changed files with 504 additions and 117 deletions

View File

@@ -0,0 +1,36 @@
//go:build darwin || freebsd
package cmd
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// peerUID returns the uid of the process on the other end of a unix socket
// connection, read via LOCAL_PEERCRED (xucred). Note: xucred carries the uid
// and group list but no pid, so audit on these platforms is uid-based.
func peerUID(c net.Conn) (int, error) {
uc, ok := c.(*net.UnixConn)
if !ok {
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
}
raw, err := uc.SyscallConn()
if err != nil {
return 0, fmt.Errorf("raw conn: %w", err)
}
var cred *unix.Xucred
var credErr error
if err := raw.Control(func(fd uintptr) {
cred, credErr = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
}); err != nil {
return 0, fmt.Errorf("getsockopt control: %w", err)
}
if credErr != nil {
return 0, fmt.Errorf("LOCAL_PEERCRED: %w", credErr)
}
return int(cred.Uid), nil
}

View File

@@ -0,0 +1,35 @@
//go:build linux
package cmd
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// peerUID returns the uid of the process on the other end of a unix socket
// connection, read from the kernel via SO_PEERCRED.
func peerUID(c net.Conn) (int, error) {
uc, ok := c.(*net.UnixConn)
if !ok {
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
}
raw, err := uc.SyscallConn()
if err != nil {
return 0, fmt.Errorf("raw conn: %w", err)
}
var cred *unix.Ucred
var credErr error
if err := raw.Control(func(fd uintptr) {
cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
}); err != nil {
return 0, fmt.Errorf("getsockopt control: %w", err)
}
if credErr != nil {
return 0, fmt.Errorf("SO_PEERCRED: %w", credErr)
}
return int(cred.Uid), nil
}

View File

@@ -0,0 +1,16 @@
//go:build !linux && !darwin && !freebsd
package cmd
import (
"fmt"
"net"
"runtime"
)
// peerUID is unimplemented on this platform, so the trust-on-first-use socket
// migration cannot run here. Configure --socket-owner explicitly, or use
// --disable-strict-socket. (Windows uses a TCP socket and never reaches this.)
func peerUID(net.Conn) (int, error) {
return 0, fmt.Errorf("peer credential check not supported on %s", runtime.GOOS)
}

View File

@@ -77,6 +77,8 @@ var (
updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool
socketOwner string
strictSocketDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",

View File

@@ -57,6 +57,9 @@ func init() {
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
serviceCmd.PersistentFlags().StringVar(&socketOwner, "socket-owner", "", "user to own the daemon control socket; restricts it to that user plus the netbird group (0660). If unset, the first client to connect claims ownership (trust-on-first-use)")
serviceCmd.PersistentFlags().BoolVar(&strictSocketDisabled, "disable-strict-socket", false, "leave the daemon control socket world-writable (0666) instead of restricting it; set via the (root-only) service command")
rootCmd.AddCommand(serviceCmd)
}

View File

@@ -4,10 +4,15 @@ package cmd
import (
"context"
"errors"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"strconv"
"strings"
"sync"
"time"
"github.com/kardianos/service"
@@ -16,6 +21,7 @@ import (
"github.com/spf13/cobra"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal/shell"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/client/system"
@@ -54,10 +60,36 @@ func (p *program) Start(svc service.Service) error {
go func() {
defer listen.Close()
srvListener := listen
if split[0] == "unix" {
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
owner := effectiveSocketOwner()
switch {
case strictSocketDisabled:
// Opt-out (root-only, via service.json): leave it world-writable.
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
}
case owner != "":
// Seeded owner (flag, MDM, or persisted TOFU result): restrict
// before serving so there is no open window.
uid, err := lookupUser(owner)
if err != nil {
log.Errorf("lookup socket owner %q: %v", owner, err)
return
}
if err := restrictSocket(split[1], uid); err != nil {
log.Errorf("restrict socket to %q: %v", owner, err)
return
}
default:
// Trust-on-first-use: open the socket now; tofuListener locks it
// to the first caller's uid on the first connection.
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
}
srvListener = &tofuListener{Listener: listen, path: split[1], owner: -1}
}
}
@@ -72,13 +104,180 @@ func (p *program) Start(svc service.Service) error {
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
if err := p.serv.Serve(srvListener); err != nil {
log.Errorf("failed to serve daemon requests: %v", err)
}
}()
return nil
}
func lookupUser(username string) (int, error) {
u, err := shell.LookupWithGetent(username)
if err != nil {
return -1, fmt.Errorf("lookup user %s: %w", username, err)
}
uid, err := strconv.Atoi(u.Uid)
if err != nil {
return -1, fmt.Errorf("parse uid %s: %w", u.Uid, err)
}
return uid, nil
}
// addGroup creates a system group if it doesn't already exist and returns the gid.
// Must run as root.
func addGroup(name string) (int, error) {
group, err := shell.LookupGroupWithGetent(name)
if err == nil {
gid, err := strconv.ParseInt(group.Gid, 10, 64)
return int(gid), err
}
// looup failed, create the group
groupadd, err := exec.LookPath("groupadd")
if err != nil {
// Fallback for Alpine/BusyBox systems.
if groupadd, err = exec.LookPath("addgroup"); err != nil {
return -1, errors.New("neither groupadd nor addgroup found")
}
}
// Use --system for a service/daemon group (no login, low GID).
out, err := exec.Command(groupadd, "--system", name).CombinedOutput()
if err != nil {
return -1, fmt.Errorf("create group %q: %w: %s", name, err, out)
}
if group, err := shell.LookupWithGetent(name); err == nil {
gid, err := strconv.ParseInt(group.Gid, 10, 64)
return int(gid), err
}
return -1, fmt.Errorf("lookup group %q: %w", name, err)
}
// restrictSocket locks the unix socket down to the owner uid plus the netbird
// group (0660). If the group cannot be created or applied, it fails closed to
// owner-only 0600 — it never leaves the socket world-writable.
func restrictSocket(path string, uid int) error {
// TODO: introduce flag to configure this (LDAP/AD usecase)
gid, err := addGroup("netbird")
if err != nil {
log.Errorf("create netbird group, failing closed to owner-only 0600: %v", err)
return chownChmod(path, uid, -1, 0600)
}
if err := chownChmod(path, uid, gid, 0660); err != nil {
log.Errorf("apply netbird group to socket, failing closed to owner-only 0600: %v", err)
return chownChmod(path, uid, -1, 0600)
}
return nil
}
// chownChmod sets ownership and mode on the socket. A gid of -1 leaves the
// group unchanged.
func chownChmod(path string, uid, gid int, mode os.FileMode) error {
if err := os.Chown(path, uid, gid); err != nil {
return fmt.Errorf("chown socket %s: %w", path, err)
}
if err := os.Chmod(path, mode); err != nil {
return fmt.Errorf("chmod socket %s: %w", path, err)
}
return nil
}
// tofuListener implements trust-on-first-use for the daemon control socket.
// The socket starts world-writable; the first caller's uid (read via the
// platform peer-credential mechanism) becomes the owner. On that first
// connection the socket is restricted (see restrictSocket) and the owner is
// persisted so the open window never reopens on later starts. Connections that
// raced in during the open window and are neither the owner nor root are
// dropped. Changing the socket mode does not disturb the already-open
// connection, so the first caller's request is served normally.
type tofuListener struct {
net.Listener
path string
mu sync.Mutex
owner int // -1 until claimed
}
func (l *tofuListener) Accept() (net.Conn, error) {
for {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
uid, err := peerUID(c)
if err != nil {
log.Errorf("read peer credentials, dropping connection: %v", err)
_ = c.Close()
continue
}
l.mu.Lock()
if l.owner == -1 {
if err := restrictSocket(l.path, uid); err != nil {
l.mu.Unlock()
_ = c.Close()
// Refuse to serve on a socket we could not lock down.
return nil, fmt.Errorf("restrict socket on first connection: %w", err)
}
l.owner = uid
persistSocketOwner(uid)
log.Infof("control socket restricted to first caller (uid %d)", uid)
l.mu.Unlock()
return c, nil
}
owner := l.owner
l.mu.Unlock()
// New connects are already gated by the 0660 perms set above; this only
// drops anything that slipped in during the brief open window.
if uid != owner && uid != 0 {
log.Warnf("dropping non-owner connection (uid %d) during socket bootstrap", uid)
_ = c.Close()
continue
}
return c, nil
}
}
// effectiveSocketOwner returns the configured socket owner: the --socket-owner
// flag when set, otherwise the owner persisted by a previous TOFU migration.
func effectiveSocketOwner() string {
if socketOwner != "" {
return socketOwner
}
params, err := loadServiceParams()
if err != nil {
log.Errorf("load service params for socket owner: %v", err)
return ""
}
if params != nil {
return params.SocketOwner
}
return ""
}
// persistSocketOwner records the TOFU-selected owner (by username) so the next
// daemon start restricts the socket immediately, with no open window.
func persistSocketOwner(uid int) {
u, err := user.LookupId(strconv.Itoa(uid))
if err != nil {
log.Errorf("resolve uid %d to username for persistence: %v", uid, err)
return
}
params, err := loadServiceParams()
if err != nil {
log.Errorf("load service params to persist socket owner: %v", err)
return
}
if params == nil {
params = currentServiceParams()
}
params.SocketOwner = u.Username
if err := saveServiceParams(params); err != nil {
log.Errorf("persist socket owner: %v", err)
}
}
func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {

View File

@@ -67,6 +67,14 @@ func buildServiceArguments() []string {
args = append(args, "--disable-networks")
}
if socketOwner != "" {
args = append(args, "--socket-owner", socketOwner)
}
if strictSocketDisabled {
args = append(args, "--disable-strict-socket")
}
return args
}
@@ -127,6 +135,8 @@ var installCmd = &cobra.Command{
return err
}
cmd.Printf("SUDO_UID: %s\n", os.Getenv("SUDO_UID"))
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}

View File

@@ -30,6 +30,8 @@ type serviceParams struct {
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
EnableCapture bool `json:"enable_capture,omitempty"`
DisableNetworks bool `json:"disable_networks,omitempty"`
SocketOwner string `json:"socket_owner,omitempty"`
DisableStrictSocket bool `json:"disable_strict_socket,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
@@ -82,6 +84,8 @@ func currentServiceParams() *serviceParams {
DisableUpdateSettings: updateSettingsDisabled,
EnableCapture: captureEnabled,
DisableNetworks: networksDisabled,
SocketOwner: socketOwner,
DisableStrictSocket: strictSocketDisabled,
}
if len(serviceEnvVars) > 0 {
@@ -154,6 +158,14 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
networksDisabled = params.DisableNetworks
}
if !serviceCmd.PersistentFlags().Changed("socket-owner") {
socketOwner = params.SocketOwner
}
if !serviceCmd.PersistentFlags().Changed("disable-strict-socket") {
strictSocketDisabled = params.DisableStrictSocket
}
applyServiceEnvParams(cmd, params)
}

View File

@@ -1,6 +1,6 @@
//go:build cgo && !osusergo && !windows
package server
package shell
import "os/user"
@@ -8,17 +8,22 @@ import "os/user"
// When CGO is enabled, os/user uses libc (getpwnam_r) which goes through
// the NSS stack natively. If it fails, the user truly doesn't exist and
// getent would also fail.
func lookupWithGetent(username string) (*user.User, error) {
func LookupWithGetent(username string) (*user.User, error) {
return user.Lookup(username)
}
// currentUserWithGetent with CGO delegates directly to os/user.Current.
func currentUserWithGetent() (*user.User, error) {
func CurrentUserWithGetent() (*user.User, error) {
return user.Current()
}
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
func LookupGroupWithGetent(name string) (*user.Group, error) {
return user.LookupGroup(name)
}
// groupIdsWithFallback with CGO delegates directly to user.GroupIds.
// libc's getgrouplist handles NSS groups natively.
func groupIdsWithFallback(u *user.User) ([]string, error) {
func GroupIdsWithFallback(u *user.User) ([]string, error) {
return u.GroupIds()
}

View File

@@ -1,6 +1,6 @@
//go:build (!cgo || osusergo) && !windows
package server
package shell
import (
"os"
@@ -13,7 +13,7 @@ import (
// lookupWithGetent looks up a user by name, falling back to getent if os/user fails.
// Without CGO, os/user only reads /etc/passwd and misses NSS-provided users.
// getent goes through the host's NSS stack.
func lookupWithGetent(username string) (*user.User, error) {
func LookupWithGetent(username string) (*user.User, error) {
u, err := user.Lookup(username)
if err == nil {
return u, nil
@@ -22,7 +22,7 @@ func lookupWithGetent(username string) (*user.User, error) {
stdErr := err
log.Debugf("os/user.Lookup(%q) failed, trying getent: %v", username, err)
u, _, getentErr := runGetent(username)
u, _, getentErr := runGetentPasswd(username)
if getentErr != nil {
log.Debugf("getent fallback for %q also failed: %v", username, getentErr)
return nil, stdErr
@@ -31,8 +31,25 @@ func lookupWithGetent(username string) (*user.User, error) {
return u, nil
}
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
func LookupGroupWithGetent(name string) (*user.Group, error) {
g, err := user.LookupGroup(name)
if err == nil {
return g, nil
}
stdErr := err
log.Debugf("os/user.LookupGroup(%q) failed, trying getent: %v", name, err)
g, getentErr := runGetentGroup(name)
if getentErr != nil {
log.Debugf("getent fallback for %q also failed: %v", name, getentErr)
return nil, stdErr
}
return g, nil
}
// currentUserWithGetent gets the current user, falling back to getent if os/user fails.
func currentUserWithGetent() (*user.User, error) {
func CurrentUserWithGetent() (*user.User, error) {
u, err := user.Current()
if err == nil {
return u, nil
@@ -42,7 +59,7 @@ func currentUserWithGetent() (*user.User, error) {
uid := strconv.Itoa(os.Getuid())
log.Debugf("os/user.Current() failed, trying getent with UID %s: %v", uid, err)
u, _, getentErr := runGetent(uid)
u, _, getentErr := runGetentPasswd(uid)
if getentErr != nil {
return nil, stdErr
}
@@ -57,7 +74,7 @@ func currentUserWithGetent() (*user.User, error) {
// only reads /etc/group and silently returns incomplete results for NSS users
// (no error, just missing groups). The id command goes through NSS and returns
// the full set.
func groupIdsWithFallback(u *user.User) ([]string, error) {
func GroupIdsWithFallback(u *user.User) ([]string, error) {
ids, err := runIdGroups(u.Username)
if err == nil {
return ids, nil

View File

@@ -1,4 +1,4 @@
package server
package shell
import (
"os/user"
@@ -15,7 +15,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
u, err := lookupWithGetent(current.Username)
u, err := LookupWithGetent(current.Username)
require.NoError(t, err)
assert.Equal(t, current.Username, u.Username)
assert.Equal(t, current.Uid, u.Uid)
@@ -23,7 +23,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
}
func TestLookupWithGetent_NonexistentUser(t *testing.T) {
_, err := lookupWithGetent("nonexistent_user_xyzzy_12345")
_, err := LookupWithGetent("nonexistent_user_xyzzy_12345")
require.Error(t, err, "should fail for nonexistent user")
}
@@ -31,7 +31,7 @@ func TestCurrentUserWithGetent(t *testing.T) {
stdUser, err := user.Current()
require.NoError(t, err)
u, err := currentUserWithGetent()
u, err := CurrentUserWithGetent()
require.NoError(t, err)
assert.Equal(t, stdUser.Uid, u.Uid)
assert.Equal(t, stdUser.Username, u.Username)
@@ -41,7 +41,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
groups, err := groupIdsWithFallback(current)
groups, err := GroupIdsWithFallback(current)
require.NoError(t, err)
require.NotEmpty(t, groups, "current user should have at least one group")
@@ -56,7 +56,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
func TestGetShellFromGetent_CurrentUser(t *testing.T) {
if runtime.GOOS == "windows" {
// Windows stub always returns empty, which is correct
shell := getShellFromGetent("1000")
shell := GetShellFromGetent("1000")
assert.Empty(t, shell, "Windows stub should return empty")
return
}
@@ -65,7 +65,7 @@ func TestGetShellFromGetent_CurrentUser(t *testing.T) {
require.NoError(t, err)
// getent may not be available on all systems (e.g., macOS without Homebrew getent)
shell := getShellFromGetent(current.Uid)
shell := GetShellFromGetent(current.Uid)
if shell == "" {
t.Log("getShellFromGetent returned empty, getent may not be available")
return
@@ -78,7 +78,7 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
t.Skip("no root user on Windows")
}
u, err := lookupWithGetent("root")
u, err := LookupWithGetent("root")
if err != nil {
t.Skip("root user not available on this system")
}
@@ -91,20 +91,20 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
// consistent and correct results when composed together.
func TestIntegration_FullLookupChain(t *testing.T) {
// Step 1: currentUserWithGetent must resolve the running user.
current, err := currentUserWithGetent()
current, err := CurrentUserWithGetent()
require.NoError(t, err, "currentUserWithGetent must resolve the running user")
require.NotEmpty(t, current.Uid)
require.NotEmpty(t, current.Username)
// Step 2: lookupWithGetent by the same username must return matching identity.
byName, err := lookupWithGetent(current.Username)
byName, err := LookupWithGetent(current.Username)
require.NoError(t, err)
assert.Equal(t, current.Uid, byName.Uid, "lookup by name should return same UID")
assert.Equal(t, current.Gid, byName.Gid, "lookup by name should return same GID")
assert.Equal(t, current.HomeDir, byName.HomeDir, "lookup by name should return same home")
// Step 3: groupIdsWithFallback must return at least the primary GID.
groups, err := groupIdsWithFallback(current)
groups, err := GroupIdsWithFallback(current)
require.NoError(t, err)
require.NotEmpty(t, groups, "user must have at least one group")
@@ -123,7 +123,7 @@ func TestIntegration_FullLookupChain(t *testing.T) {
// Step 4: getShellFromGetent should either return a valid shell path or empty
// (empty is OK when getent is not available, e.g. macOS without Homebrew getent).
if runtime.GOOS != "windows" {
shell := getShellFromGetent(current.Uid)
shell := GetShellFromGetent(current.Uid)
if shell != "" {
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
}
@@ -138,10 +138,10 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
require.NoError(t, err)
// Simulate the SSH server flow: lookup user, then get their groups.
resolved, err := lookupWithGetent(current.Username)
resolved, err := LookupWithGetent(current.Username)
require.NoError(t, err)
groups, err := groupIdsWithFallback(resolved)
groups, err := GroupIdsWithFallback(resolved)
require.NoError(t, err)
require.NotEmpty(t, groups, "resolved user must have groups")
@@ -154,19 +154,3 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
}
}
}
// TestIntegration_ShellLookupChain tests the full shell resolution chain
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
func TestIntegration_ShellLookupChain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Unix shell lookup not applicable on Windows")
}
current, err := user.Current()
require.NoError(t, err)
// getUserShell is the top-level function used by the SSH server.
shell := getUserShell(current.Uid)
require.NotEmpty(t, shell, "getUserShell must always return a shell")
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
}

View File

@@ -1,6 +1,6 @@
//go:build !windows
package server
package shell
import (
"context"
@@ -14,19 +14,25 @@ import (
const getentTimeout = 5 * time.Second
// getShellFromGetent gets a user's login shell via getent by UID.
// GetShellFromGetent gets a user's login shell via getent by UID.
// This is needed even with CGO because getShellFromPasswd reads /etc/passwd
// directly and won't find NSS-provided users there.
func getShellFromGetent(userID string) string {
_, shell, err := runGetent(userID)
func GetShellFromGetent(userID string) string {
_, shell, err := runGetentPasswd(userID)
if err != nil {
return ""
}
return shell
}
// runGetent executes `getent passwd <query>` and returns the user and login shell.
func runGetent(query string) (*user.User, string, error) {
// GetUserFromGetent returns the resolved group from either a uid or username
func GetUserFromGetent(user string) (*user.User, error) {
u, _, err := runGetentPasswd(user)
return u, err
}
// runGetentPasswd executes `getent passwd <query>` and returns the user and login shell.
func runGetentPasswd(query string) (*user.User, string, error) {
if !validateGetentInput(query) {
return nil, "", fmt.Errorf("invalid getent input: %q", query)
}
@@ -42,6 +48,23 @@ func runGetent(query string) (*user.User, string, error) {
return parseGetentPasswd(string(out))
}
// runGetentGroup executes `getent group <query>` and returns the group
func runGetentGroup(query string) (*user.Group, error) {
if !validateGetentInput(query) {
return nil, fmt.Errorf("invalid getent input: %q", query)
}
ctx, cancel := context.WithTimeout(context.Background(), getentTimeout)
defer cancel()
out, err := exec.CommandContext(ctx, "getent", "group", query).Output()
if err != nil {
return nil, fmt.Errorf("getent passwd%s: %w", query, err)
}
return parseGetentGroup(string(out))
}
// parseGetentPasswd parses getent passwd output: "name:x:uid:gid:gecos:home:shell"
func parseGetentPasswd(output string) (*user.User, string, error) {
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
@@ -67,6 +90,20 @@ func parseGetentPasswd(output string) (*user.User, string, error) {
}, shell, nil
}
// parseGetentGroup parses getent group output: "group:x:gid:user"
func parseGetentGroup(output string) (*user.Group, error) {
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
if len(fields) < 4 {
return nil, fmt.Errorf("unexpected getent output (need 4+ fields): %q", output)
}
if fields[0] == "" || fields[2] == "" {
return nil, fmt.Errorf("missing required fields in getent output: %q", output)
}
return &user.Group{Gid: fields[2], Name: fields[0]}, nil
}
// validateGetentInput checks that the input is safe to pass to getent or id.
// Allows POSIX usernames, numeric UIDs, and common NSS extensions
// (@ for Kerberos, $ for Samba, + for NIS compat).

View File

@@ -1,6 +1,6 @@
//go:build !windows
package server
package shell
import (
"os/exec"
@@ -195,7 +195,7 @@ func TestRunGetent_RootUser(t *testing.T) {
t.Skip("getent not available on this system")
}
u, shell, err := runGetent("root")
u, shell, err := runGetentPasswd("root")
require.NoError(t, err)
assert.Equal(t, "root", u.Username)
assert.Equal(t, "0", u.Uid)
@@ -208,7 +208,7 @@ func TestRunGetent_ByUID(t *testing.T) {
t.Skip("getent not available on this system")
}
u, _, err := runGetent("0")
u, _, err := runGetentPasswd("0")
require.NoError(t, err)
assert.Equal(t, "root", u.Username)
assert.Equal(t, "0", u.Uid)
@@ -219,15 +219,15 @@ func TestRunGetent_NonexistentUser(t *testing.T) {
t.Skip("getent not available on this system")
}
_, _, err := runGetent("nonexistent_user_xyzzy_12345")
_, _, err := runGetentPasswd("nonexistent_user_xyzzy_12345")
assert.Error(t, err)
}
func TestRunGetent_InvalidInput(t *testing.T) {
_, _, err := runGetent("")
_, _, err := runGetentPasswd("")
assert.Error(t, err)
_, _, err = runGetent("user\x00name")
_, _, err = runGetentPasswd("user\x00name")
assert.Error(t, err)
}
@@ -236,7 +236,7 @@ func TestRunGetent_NotAvailable(t *testing.T) {
t.Skip("getent is available, can't test missing case")
}
_, _, err := runGetent("root")
_, _, err := runGetentPasswd("root")
assert.Error(t, err, "should fail when getent is not installed")
}
@@ -283,7 +283,7 @@ func TestGetentResultsMatchStdlib(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
getentUser, _, err := runGetent(current.Username)
getentUser, _, err := runGetentPasswd(current.Username)
require.NoError(t, err)
assert.Equal(t, current.Username, getentUser.Username, "username should match")
@@ -300,7 +300,7 @@ func TestGetentResultsMatchStdlib_ByUID(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
getentUser, _, err := runGetent(current.Uid)
getentUser, _, err := runGetentPasswd(current.Uid)
require.NoError(t, err)
assert.Equal(t, current.Username, getentUser.Username, "username should match when looked up by UID")
@@ -356,7 +356,7 @@ func TestGetShellFromPasswd_CurrentUser(t *testing.T) {
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
if _, err := exec.LookPath("getent"); err == nil {
_, getentShell, getentErr := runGetent(current.Uid)
_, getentShell, getentErr := runGetentPasswd(current.Uid)
if getentErr == nil && getentShell != "" {
assert.Equal(t, getentShell, shell, "shell from /etc/passwd should match getent")
}
@@ -400,7 +400,7 @@ func TestGetShellFromPasswd_MatchesGetentForKnownUsers(t *testing.T) {
continue
}
_, getentShell, err := runGetent(uid)
_, getentShell, err := runGetentPasswd(uid)
if err != nil {
continue
}

View File

@@ -1,26 +1,26 @@
//go:build windows
package server
package shell
import "os/user"
// lookupWithGetent on Windows just delegates to os/user.Lookup.
// Windows does not use NSS/getent; its user lookup works without CGO.
func lookupWithGetent(username string) (*user.User, error) {
func LookupWithGetent(username string) (*user.User, error) {
return user.Lookup(username)
}
// currentUserWithGetent on Windows just delegates to os/user.Current.
func currentUserWithGetent() (*user.User, error) {
func CurrentUserWithGetent() (*user.User, error) {
return user.Current()
}
// getShellFromGetent is a no-op on Windows; shell resolution uses PowerShell detection.
func getShellFromGetent(_ string) string {
func GetShellFromGetent(_ string) string {
return ""
}
// groupIdsWithFallback on Windows just delegates to u.GroupIds().
func groupIdsWithFallback(u *user.User) ([]string, error) {
func GroupIdsWithFallback(u *user.User) ([]string, error) {
return u.GroupIds()
}

View File

@@ -1,17 +1,14 @@
package server
package shell
import (
"bufio"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
@@ -24,7 +21,7 @@ const (
// getUserShell returns the appropriate shell for the given user ID
// Handles all platform-specific logic and fallbacks consistently
func getUserShell(userID string) string {
func GetUserShell(userID string) string {
switch runtime.GOOS {
case "windows":
return getWindowsUserShell()
@@ -56,7 +53,7 @@ func getUnixUserShell(userID string) string {
return shell
}
if shell := getShellFromGetent(userID); shell != "" {
if shell := GetShellFromGetent(userID); shell != "" {
return shell
}
@@ -101,8 +98,8 @@ func getShellFromPasswd(userID string) string {
return ""
}
// prepareUserEnv prepares environment variables for user execution
func prepareUserEnv(user *user.User, shell string) []string {
// PrepareUserEnv prepares environment variables for user execution
func PrepareUserEnv(user *user.User, shell string) []string {
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
if runtime.GOOS == "windows" {
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
@@ -119,7 +116,7 @@ func prepareUserEnv(user *user.User, shell string) []string {
// acceptEnv checks if environment variable from SSH client should be accepted
// This is a whitelist of variables that SSH clients can send to the server
func acceptEnv(envVar string) bool {
func AcceptEnv(envVar string) bool {
varName := envVar
if idx := strings.Index(envVar, "="); idx != -1 {
varName = envVar[:idx]
@@ -156,29 +153,3 @@ func acceptEnv(envVar string) bool {
return false
}
// prepareSSHEnv prepares SSH protocol-specific environment variables
// These variables provide information about the SSH connection itself
func prepareSSHEnv(session ssh.Session) []string {
remoteAddr := session.RemoteAddr()
localAddr := session.LocalAddr()
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
if err != nil {
remoteHost = remoteAddr.String()
remotePort = "0"
}
localHost, localPort, err := net.SplitHostPort(localAddr.String())
if err != nil {
localHost = localAddr.String()
localPort = strconv.Itoa(InternalSSHPort)
}
return []string{
// SSH_CLIENT format: "client_ip client_port server_port"
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
}
}

View File

@@ -0,0 +1,26 @@
package shell
import (
"os/user"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIntegration_ShellLookupChain tests the full shell resolution chain
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
func TestIntegration_ShellLookupChain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Unix shell lookup not applicable on Windows")
}
current, err := user.Current()
require.NoError(t, err)
// getUserShell is the top-level function used by the SSH server.
shell := GetUserShell(current.Uid)
require.NotEmpty(t, shell, "getUserShell must always return a shell")
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
}

View File

@@ -19,6 +19,7 @@ import (
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
"github.com/netbirdio/netbird/client/internal/shell"
log "github.com/sirupsen/logrus"
)
@@ -146,10 +147,10 @@ func (s *Server) createShellCommand(ctx context.Context, shell string, args []st
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
if shell.AcceptEnv(v) {
env = append(env, v)
}
}

View File

@@ -247,10 +247,10 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
userEnv, err := s.getUserEnvironment(logger, username, domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
if shell.AcceptEnv(v) {
env = append(env, v)
}
}
@@ -260,7 +260,7 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
env := userEnv
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
if shell.AcceptEnv(v) {
env = append(env, v)
}
}
@@ -273,7 +273,7 @@ func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privileg
return false
}
shell := getUserShell(privilegeResult.User.Uid)
shell := shell.GetUserShell(privilegeResult.User.Uid)
logger.Infof("starting interactive shell: %s", shell)
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
@@ -384,7 +384,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _
}
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
shell := shell.GetUserShell(localUser.Uid)
req := PtyExecutionRequest{
Shell: shell,

View File

@@ -3,11 +3,15 @@ package server
import (
"errors"
"fmt"
"net"
"os"
"os/user"
"runtime"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
"github.com/netbirdio/netbird/client/internal/shell"
log "github.com/sirupsen/logrus"
)
@@ -23,8 +27,8 @@ func isPlatformUnix() bool {
// Dependency injection variables for testing - allows mocking dynamic runtime checks
var (
getCurrentUser = currentUserWithGetent
lookupUser = lookupWithGetent
getCurrentUser = shell.CurrentUserWithGetent
lookupUser = shell.LookupWithGetent
getCurrentOS = func() string { return runtime.GOOS }
getIsProcessPrivileged = isCurrentProcessPrivileged
@@ -409,3 +413,29 @@ func isWindowsElevated() bool {
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
return false
}
// prepareSSHEnv prepares SSH protocol-specific environment variables
// These variables provide information about the SSH connection itself
func prepareSSHEnv(session ssh.Session) []string {
remoteAddr := session.RemoteAddr()
localAddr := session.LocalAddr()
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
if err != nil {
remoteHost = remoteAddr.String()
remotePort = "0"
}
localHost, localPort, err := net.SplitHostPort(localAddr.String())
if err != nil {
localHost = localAddr.String()
localPort = strconv.Itoa(InternalSSHPort)
}
return []string{
// SSH_CLIENT format: "client_ip client_port server_port"
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
}
}

View File

@@ -15,6 +15,7 @@ import (
"strconv"
"github.com/gliderlabs/ssh"
"github.com/netbirdio/netbird/client/internal/shell"
log "github.com/sirupsen/logrus"
)
@@ -160,7 +161,7 @@ func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []u
// getSupplementaryGroups retrieves supplementary group IDs for a user.
// Uses id/getent fallback for NSS users in CGO_ENABLED=0 builds.
func (s *Server) getSupplementaryGroups(u *user.User) ([]uint32, error) {
groupIDStrings, err := groupIdsWithFallback(u)
groupIDStrings, err := shell.GroupIdsWithFallback(u)
if err != nil {
return nil, fmt.Errorf("get group IDs for user %s: %w", u.Username, err)
}
@@ -196,7 +197,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
GID: gid,
Groups: groups,
WorkingDir: localUser.HomeDir,
Shell: getUserShell(localUser.Uid),
Shell: shell.GetUserShell(localUser.Uid),
Command: session.RawCommand(),
PTY: hasPty,
}
@@ -228,7 +229,7 @@ func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq s
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
shell := getUserShell(localUser.Uid)
shell := shell.GetUserShell(localUser.Uid)
args := s.getShellCommandArgs(shell, session.RawCommand())
cmd := s.createShellCommand(session.Context(), shell, args)
@@ -245,12 +246,12 @@ func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh
termType = "xterm-256color"
}
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
env = append(env, fmt.Sprintf("TERM=%s", termType))
for _, v := range session.Environ() {
if acceptEnv(v) {
if shell.AcceptEnv(v) {
env = append(env, v)
}
}

View File

@@ -13,6 +13,8 @@ import (
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/shell"
)
// validateUsername validates Windows usernames according to SAM Account Name rules
@@ -104,7 +106,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
sh := shell.GetUserShell(localUser.Uid)
rawCmd := session.RawCommand()
var command string
@@ -116,7 +118,7 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
Username: username,
Domain: domain,
WorkingDir: localUser.HomeDir,
Shell: shell,
Shell: sh,
Command: command,
}