Compare commits

..

3 Commits

Author SHA1 Message Date
bcmmbaga
37ce40ede6 build client config in sync response 2026-06-05 23:53:48 +03:00
bcmmbaga
318014552f Merge branch 'main' into refactor/mgmt-bootstrap 2026-06-05 22:02:25 +03:00
bcmmbaga
78ed62535e Clean up middleware and move auth middlewa into module 2026-05-26 20:45:57 +03:00
40 changed files with 256 additions and 1102 deletions

View File

@@ -29,10 +29,10 @@ jobs:
persist-credentials: false
- name: Generate FreeBSD port diff
run: bash -x release_files/freebsd-port-diff.sh
run: bash release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash -x release_files/freebsd-port-issue-body.sh
run: bash release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff

View File

@@ -65,7 +65,7 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 62914560 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
if [ ${SIZE} -gt 58720256 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
exit 1
fi

View File

@@ -1,36 +0,0 @@
//go:build darwin || freebsd
package cmd
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// peerUID returns the uid of the process on the other end of a unix socket
// connection, read via LOCAL_PEERCRED (xucred). Note: xucred carries the uid
// and group list but no pid, so audit on these platforms is uid-based.
func peerUID(c net.Conn) (int, error) {
uc, ok := c.(*net.UnixConn)
if !ok {
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
}
raw, err := uc.SyscallConn()
if err != nil {
return 0, fmt.Errorf("raw conn: %w", err)
}
var cred *unix.Xucred
var credErr error
if err := raw.Control(func(fd uintptr) {
cred, credErr = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
}); err != nil {
return 0, fmt.Errorf("getsockopt control: %w", err)
}
if credErr != nil {
return 0, fmt.Errorf("LOCAL_PEERCRED: %w", credErr)
}
return int(cred.Uid), nil
}

View File

@@ -1,35 +0,0 @@
//go:build linux
package cmd
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// peerUID returns the uid of the process on the other end of a unix socket
// connection, read from the kernel via SO_PEERCRED.
func peerUID(c net.Conn) (int, error) {
uc, ok := c.(*net.UnixConn)
if !ok {
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
}
raw, err := uc.SyscallConn()
if err != nil {
return 0, fmt.Errorf("raw conn: %w", err)
}
var cred *unix.Ucred
var credErr error
if err := raw.Control(func(fd uintptr) {
cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
}); err != nil {
return 0, fmt.Errorf("getsockopt control: %w", err)
}
if credErr != nil {
return 0, fmt.Errorf("SO_PEERCRED: %w", credErr)
}
return int(cred.Uid), nil
}

View File

@@ -1,16 +0,0 @@
//go:build !linux && !darwin && !freebsd
package cmd
import (
"fmt"
"net"
"runtime"
)
// peerUID is unimplemented on this platform, so the trust-on-first-use socket
// migration cannot run here. Configure --socket-owner explicitly, or use
// --disable-strict-socket. (Windows uses a TCP socket and never reaches this.)
func peerUID(net.Conn) (int, error) {
return 0, fmt.Errorf("peer credential check not supported on %s", runtime.GOOS)
}

View File

@@ -77,8 +77,6 @@ var (
updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool
socketOwner string
strictSocketDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",

View File

@@ -57,9 +57,6 @@ func init() {
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
serviceCmd.PersistentFlags().StringVar(&socketOwner, "socket-owner", "", "user to own the daemon control socket; restricts it to that user plus the netbird group (0660). If unset, the first client to connect claims ownership (trust-on-first-use)")
serviceCmd.PersistentFlags().BoolVar(&strictSocketDisabled, "disable-strict-socket", false, "leave the daemon control socket world-writable (0666) instead of restricting it; set via the (root-only) service command")
rootCmd.AddCommand(serviceCmd)
}

View File

@@ -4,15 +4,10 @@ package cmd
import (
"context"
"errors"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"strconv"
"strings"
"sync"
"time"
"github.com/kardianos/service"
@@ -21,7 +16,6 @@ import (
"github.com/spf13/cobra"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal/shell"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/client/system"
@@ -60,36 +54,10 @@ func (p *program) Start(svc service.Service) error {
go func() {
defer listen.Close()
srvListener := listen
if split[0] == "unix" {
owner := effectiveSocketOwner()
switch {
case strictSocketDisabled:
// Opt-out (root-only, via service.json): leave it world-writable.
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
}
case owner != "":
// Seeded owner (flag, MDM, or persisted TOFU result): restrict
// before serving so there is no open window.
uid, err := lookupUser(owner)
if err != nil {
log.Errorf("lookup socket owner %q: %v", owner, err)
return
}
if err := restrictSocket(split[1], uid); err != nil {
log.Errorf("restrict socket to %q: %v", owner, err)
return
}
default:
// Trust-on-first-use: open the socket now; tofuListener locks it
// to the first caller's uid on the first connection.
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
}
srvListener = &tofuListener{Listener: listen, path: split[1], owner: -1}
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
}
}
@@ -104,180 +72,13 @@ func (p *program) Start(svc service.Service) error {
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(srvListener); err != nil {
if err := p.serv.Serve(listen); err != nil {
log.Errorf("failed to serve daemon requests: %v", err)
}
}()
return nil
}
func lookupUser(username string) (int, error) {
u, err := shell.LookupWithGetent(username)
if err != nil {
return -1, fmt.Errorf("lookup user %s: %w", username, err)
}
uid, err := strconv.Atoi(u.Uid)
if err != nil {
return -1, fmt.Errorf("parse uid %s: %w", u.Uid, err)
}
return uid, nil
}
// addGroup creates a system group if it doesn't already exist and returns the gid.
// Must run as root.
func addGroup(name string) (int, error) {
group, err := shell.LookupGroupWithGetent(name)
if err == nil {
gid, err := strconv.ParseInt(group.Gid, 10, 64)
return int(gid), err
}
// looup failed, create the group
groupadd, err := exec.LookPath("groupadd")
if err != nil {
// Fallback for Alpine/BusyBox systems.
if groupadd, err = exec.LookPath("addgroup"); err != nil {
return -1, errors.New("neither groupadd nor addgroup found")
}
}
// Use --system for a service/daemon group (no login, low GID).
out, err := exec.Command(groupadd, "--system", name).CombinedOutput()
if err != nil {
return -1, fmt.Errorf("create group %q: %w: %s", name, err, out)
}
if group, err := shell.LookupWithGetent(name); err == nil {
gid, err := strconv.ParseInt(group.Gid, 10, 64)
return int(gid), err
}
return -1, fmt.Errorf("lookup group %q: %w", name, err)
}
// restrictSocket locks the unix socket down to the owner uid plus the netbird
// group (0660). If the group cannot be created or applied, it fails closed to
// owner-only 0600 — it never leaves the socket world-writable.
func restrictSocket(path string, uid int) error {
// TODO: introduce flag to configure this (LDAP/AD usecase)
gid, err := addGroup("netbird")
if err != nil {
log.Errorf("create netbird group, failing closed to owner-only 0600: %v", err)
return chownChmod(path, uid, -1, 0600)
}
if err := chownChmod(path, uid, gid, 0660); err != nil {
log.Errorf("apply netbird group to socket, failing closed to owner-only 0600: %v", err)
return chownChmod(path, uid, -1, 0600)
}
return nil
}
// chownChmod sets ownership and mode on the socket. A gid of -1 leaves the
// group unchanged.
func chownChmod(path string, uid, gid int, mode os.FileMode) error {
if err := os.Chown(path, uid, gid); err != nil {
return fmt.Errorf("chown socket %s: %w", path, err)
}
if err := os.Chmod(path, mode); err != nil {
return fmt.Errorf("chmod socket %s: %w", path, err)
}
return nil
}
// tofuListener implements trust-on-first-use for the daemon control socket.
// The socket starts world-writable; the first caller's uid (read via the
// platform peer-credential mechanism) becomes the owner. On that first
// connection the socket is restricted (see restrictSocket) and the owner is
// persisted so the open window never reopens on later starts. Connections that
// raced in during the open window and are neither the owner nor root are
// dropped. Changing the socket mode does not disturb the already-open
// connection, so the first caller's request is served normally.
type tofuListener struct {
net.Listener
path string
mu sync.Mutex
owner int // -1 until claimed
}
func (l *tofuListener) Accept() (net.Conn, error) {
for {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
uid, err := peerUID(c)
if err != nil {
log.Errorf("read peer credentials, dropping connection: %v", err)
_ = c.Close()
continue
}
l.mu.Lock()
if l.owner == -1 {
if err := restrictSocket(l.path, uid); err != nil {
l.mu.Unlock()
_ = c.Close()
// Refuse to serve on a socket we could not lock down.
return nil, fmt.Errorf("restrict socket on first connection: %w", err)
}
l.owner = uid
persistSocketOwner(uid)
log.Infof("control socket restricted to first caller (uid %d)", uid)
l.mu.Unlock()
return c, nil
}
owner := l.owner
l.mu.Unlock()
// New connects are already gated by the 0660 perms set above; this only
// drops anything that slipped in during the brief open window.
if uid != owner && uid != 0 {
log.Warnf("dropping non-owner connection (uid %d) during socket bootstrap", uid)
_ = c.Close()
continue
}
return c, nil
}
}
// effectiveSocketOwner returns the configured socket owner: the --socket-owner
// flag when set, otherwise the owner persisted by a previous TOFU migration.
func effectiveSocketOwner() string {
if socketOwner != "" {
return socketOwner
}
params, err := loadServiceParams()
if err != nil {
log.Errorf("load service params for socket owner: %v", err)
return ""
}
if params != nil {
return params.SocketOwner
}
return ""
}
// persistSocketOwner records the TOFU-selected owner (by username) so the next
// daemon start restricts the socket immediately, with no open window.
func persistSocketOwner(uid int) {
u, err := user.LookupId(strconv.Itoa(uid))
if err != nil {
log.Errorf("resolve uid %d to username for persistence: %v", uid, err)
return
}
params, err := loadServiceParams()
if err != nil {
log.Errorf("load service params to persist socket owner: %v", err)
return
}
if params == nil {
params = currentServiceParams()
}
params.SocketOwner = u.Username
if err := saveServiceParams(params); err != nil {
log.Errorf("persist socket owner: %v", err)
}
}
func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {

View File

@@ -67,14 +67,6 @@ func buildServiceArguments() []string {
args = append(args, "--disable-networks")
}
if socketOwner != "" {
args = append(args, "--socket-owner", socketOwner)
}
if strictSocketDisabled {
args = append(args, "--disable-strict-socket")
}
return args
}
@@ -135,8 +127,6 @@ var installCmd = &cobra.Command{
return err
}
cmd.Printf("SUDO_UID: %s\n", os.Getenv("SUDO_UID"))
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}

View File

@@ -30,8 +30,6 @@ type serviceParams struct {
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
EnableCapture bool `json:"enable_capture,omitempty"`
DisableNetworks bool `json:"disable_networks,omitempty"`
SocketOwner string `json:"socket_owner,omitempty"`
DisableStrictSocket bool `json:"disable_strict_socket,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
@@ -84,8 +82,6 @@ func currentServiceParams() *serviceParams {
DisableUpdateSettings: updateSettingsDisabled,
EnableCapture: captureEnabled,
DisableNetworks: networksDisabled,
SocketOwner: socketOwner,
DisableStrictSocket: strictSocketDisabled,
}
if len(serviceEnvVars) > 0 {
@@ -158,14 +154,6 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
networksDisabled = params.DisableNetworks
}
if !serviceCmd.PersistentFlags().Changed("socket-owner") {
socketOwner = params.SocketOwner
}
if !serviceCmd.PersistentFlags().Changed("disable-strict-socket") {
strictSocketDisabled = params.DisableStrictSocket
}
applyServiceEnvParams(cmd, params)
}

View File

@@ -806,8 +806,6 @@ func (g *BundleGenerator) addSyncResponse() error {
AllowPartial: true,
}
g.maskSecrets()
jsonBytes, err := options.Marshal(g.syncResponse)
if err != nil {
return fmt.Errorf("generate json: %w", err)
@@ -820,27 +818,6 @@ func (g *BundleGenerator) addSyncResponse() error {
return nil
}
func (g *BundleGenerator) maskSecrets() {
if g.syncResponse == nil || g.syncResponse.NetbirdConfig == nil {
return
}
if g.syncResponse.NetbirdConfig.Flow != nil {
g.syncResponse.NetbirdConfig.Flow.TokenPayload = maskedValue
}
if g.syncResponse.NetbirdConfig.Relay != nil {
g.syncResponse.NetbirdConfig.Relay.TokenPayload = maskedValue
}
for i := range g.syncResponse.NetbirdConfig.Turns {
if g.syncResponse.NetbirdConfig.Turns[i] != nil {
g.syncResponse.NetbirdConfig.Turns[i].Password = maskedValue
}
}
}
func (g *BundleGenerator) addStateFile() error {
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()

View File

@@ -1,26 +0,0 @@
package shell
import (
"os/user"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIntegration_ShellLookupChain tests the full shell resolution chain
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
func TestIntegration_ShellLookupChain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Unix shell lookup not applicable on Windows")
}
current, err := user.Current()
require.NoError(t, err)
// getUserShell is the top-level function used by the SSH server.
shell := GetUserShell(current.Uid)
require.NotEmpty(t, shell, "getUserShell must always return a shell")
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
}

View File

@@ -19,7 +19,6 @@ import (
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
"github.com/netbirdio/netbird/client/internal/shell"
log "github.com/sirupsen/logrus"
)
@@ -147,10 +146,10 @@ func (s *Server) createShellCommand(ctx context.Context, shell string, args []st
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if shell.AcceptEnv(v) {
if acceptEnv(v) {
env = append(env, v)
}
}

View File

@@ -247,10 +247,10 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
userEnv, err := s.getUserEnvironment(logger, username, domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if shell.AcceptEnv(v) {
if acceptEnv(v) {
env = append(env, v)
}
}
@@ -260,7 +260,7 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
env := userEnv
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if shell.AcceptEnv(v) {
if acceptEnv(v) {
env = append(env, v)
}
}
@@ -273,7 +273,7 @@ func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privileg
return false
}
shell := shell.GetUserShell(privilegeResult.User.Uid)
shell := getUserShell(privilegeResult.User.Uid)
logger.Infof("starting interactive shell: %s", shell)
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
@@ -384,7 +384,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _
}
username, domain := s.parseUsername(localUser.Username)
shell := shell.GetUserShell(localUser.Uid)
shell := getUserShell(localUser.Uid)
req := PtyExecutionRequest{
Shell: shell,

View File

@@ -1,6 +1,6 @@
//go:build cgo && !osusergo && !windows
package shell
package server
import "os/user"
@@ -8,22 +8,17 @@ import "os/user"
// When CGO is enabled, os/user uses libc (getpwnam_r) which goes through
// the NSS stack natively. If it fails, the user truly doesn't exist and
// getent would also fail.
func LookupWithGetent(username string) (*user.User, error) {
func lookupWithGetent(username string) (*user.User, error) {
return user.Lookup(username)
}
// currentUserWithGetent with CGO delegates directly to os/user.Current.
func CurrentUserWithGetent() (*user.User, error) {
func currentUserWithGetent() (*user.User, error) {
return user.Current()
}
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
func LookupGroupWithGetent(name string) (*user.Group, error) {
return user.LookupGroup(name)
}
// groupIdsWithFallback with CGO delegates directly to user.GroupIds.
// libc's getgrouplist handles NSS groups natively.
func GroupIdsWithFallback(u *user.User) ([]string, error) {
func groupIdsWithFallback(u *user.User) ([]string, error) {
return u.GroupIds()
}

View File

@@ -1,6 +1,6 @@
//go:build (!cgo || osusergo) && !windows
package shell
package server
import (
"os"
@@ -13,7 +13,7 @@ import (
// lookupWithGetent looks up a user by name, falling back to getent if os/user fails.
// Without CGO, os/user only reads /etc/passwd and misses NSS-provided users.
// getent goes through the host's NSS stack.
func LookupWithGetent(username string) (*user.User, error) {
func lookupWithGetent(username string) (*user.User, error) {
u, err := user.Lookup(username)
if err == nil {
return u, nil
@@ -22,7 +22,7 @@ func LookupWithGetent(username string) (*user.User, error) {
stdErr := err
log.Debugf("os/user.Lookup(%q) failed, trying getent: %v", username, err)
u, _, getentErr := runGetentPasswd(username)
u, _, getentErr := runGetent(username)
if getentErr != nil {
log.Debugf("getent fallback for %q also failed: %v", username, getentErr)
return nil, stdErr
@@ -31,25 +31,8 @@ func LookupWithGetent(username string) (*user.User, error) {
return u, nil
}
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
func LookupGroupWithGetent(name string) (*user.Group, error) {
g, err := user.LookupGroup(name)
if err == nil {
return g, nil
}
stdErr := err
log.Debugf("os/user.LookupGroup(%q) failed, trying getent: %v", name, err)
g, getentErr := runGetentGroup(name)
if getentErr != nil {
log.Debugf("getent fallback for %q also failed: %v", name, getentErr)
return nil, stdErr
}
return g, nil
}
// currentUserWithGetent gets the current user, falling back to getent if os/user fails.
func CurrentUserWithGetent() (*user.User, error) {
func currentUserWithGetent() (*user.User, error) {
u, err := user.Current()
if err == nil {
return u, nil
@@ -59,7 +42,7 @@ func CurrentUserWithGetent() (*user.User, error) {
uid := strconv.Itoa(os.Getuid())
log.Debugf("os/user.Current() failed, trying getent with UID %s: %v", uid, err)
u, _, getentErr := runGetentPasswd(uid)
u, _, getentErr := runGetent(uid)
if getentErr != nil {
return nil, stdErr
}
@@ -74,7 +57,7 @@ func CurrentUserWithGetent() (*user.User, error) {
// only reads /etc/group and silently returns incomplete results for NSS users
// (no error, just missing groups). The id command goes through NSS and returns
// the full set.
func GroupIdsWithFallback(u *user.User) ([]string, error) {
func groupIdsWithFallback(u *user.User) ([]string, error) {
ids, err := runIdGroups(u.Username)
if err == nil {
return ids, nil

View File

@@ -1,4 +1,4 @@
package shell
package server
import (
"os/user"
@@ -15,7 +15,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
u, err := LookupWithGetent(current.Username)
u, err := lookupWithGetent(current.Username)
require.NoError(t, err)
assert.Equal(t, current.Username, u.Username)
assert.Equal(t, current.Uid, u.Uid)
@@ -23,7 +23,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
}
func TestLookupWithGetent_NonexistentUser(t *testing.T) {
_, err := LookupWithGetent("nonexistent_user_xyzzy_12345")
_, err := lookupWithGetent("nonexistent_user_xyzzy_12345")
require.Error(t, err, "should fail for nonexistent user")
}
@@ -31,7 +31,7 @@ func TestCurrentUserWithGetent(t *testing.T) {
stdUser, err := user.Current()
require.NoError(t, err)
u, err := CurrentUserWithGetent()
u, err := currentUserWithGetent()
require.NoError(t, err)
assert.Equal(t, stdUser.Uid, u.Uid)
assert.Equal(t, stdUser.Username, u.Username)
@@ -41,7 +41,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
groups, err := GroupIdsWithFallback(current)
groups, err := groupIdsWithFallback(current)
require.NoError(t, err)
require.NotEmpty(t, groups, "current user should have at least one group")
@@ -56,7 +56,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
func TestGetShellFromGetent_CurrentUser(t *testing.T) {
if runtime.GOOS == "windows" {
// Windows stub always returns empty, which is correct
shell := GetShellFromGetent("1000")
shell := getShellFromGetent("1000")
assert.Empty(t, shell, "Windows stub should return empty")
return
}
@@ -65,7 +65,7 @@ func TestGetShellFromGetent_CurrentUser(t *testing.T) {
require.NoError(t, err)
// getent may not be available on all systems (e.g., macOS without Homebrew getent)
shell := GetShellFromGetent(current.Uid)
shell := getShellFromGetent(current.Uid)
if shell == "" {
t.Log("getShellFromGetent returned empty, getent may not be available")
return
@@ -78,7 +78,7 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
t.Skip("no root user on Windows")
}
u, err := LookupWithGetent("root")
u, err := lookupWithGetent("root")
if err != nil {
t.Skip("root user not available on this system")
}
@@ -91,20 +91,20 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
// consistent and correct results when composed together.
func TestIntegration_FullLookupChain(t *testing.T) {
// Step 1: currentUserWithGetent must resolve the running user.
current, err := CurrentUserWithGetent()
current, err := currentUserWithGetent()
require.NoError(t, err, "currentUserWithGetent must resolve the running user")
require.NotEmpty(t, current.Uid)
require.NotEmpty(t, current.Username)
// Step 2: lookupWithGetent by the same username must return matching identity.
byName, err := LookupWithGetent(current.Username)
byName, err := lookupWithGetent(current.Username)
require.NoError(t, err)
assert.Equal(t, current.Uid, byName.Uid, "lookup by name should return same UID")
assert.Equal(t, current.Gid, byName.Gid, "lookup by name should return same GID")
assert.Equal(t, current.HomeDir, byName.HomeDir, "lookup by name should return same home")
// Step 3: groupIdsWithFallback must return at least the primary GID.
groups, err := GroupIdsWithFallback(current)
groups, err := groupIdsWithFallback(current)
require.NoError(t, err)
require.NotEmpty(t, groups, "user must have at least one group")
@@ -123,7 +123,7 @@ func TestIntegration_FullLookupChain(t *testing.T) {
// Step 4: getShellFromGetent should either return a valid shell path or empty
// (empty is OK when getent is not available, e.g. macOS without Homebrew getent).
if runtime.GOOS != "windows" {
shell := GetShellFromGetent(current.Uid)
shell := getShellFromGetent(current.Uid)
if shell != "" {
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
}
@@ -138,10 +138,10 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
require.NoError(t, err)
// Simulate the SSH server flow: lookup user, then get their groups.
resolved, err := LookupWithGetent(current.Username)
resolved, err := lookupWithGetent(current.Username)
require.NoError(t, err)
groups, err := GroupIdsWithFallback(resolved)
groups, err := groupIdsWithFallback(resolved)
require.NoError(t, err)
require.NotEmpty(t, groups, "resolved user must have groups")
@@ -154,3 +154,19 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
}
}
}
// TestIntegration_ShellLookupChain tests the full shell resolution chain
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
func TestIntegration_ShellLookupChain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Unix shell lookup not applicable on Windows")
}
current, err := user.Current()
require.NoError(t, err)
// getUserShell is the top-level function used by the SSH server.
shell := getUserShell(current.Uid)
require.NotEmpty(t, shell, "getUserShell must always return a shell")
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
}

View File

@@ -1,6 +1,6 @@
//go:build !windows
package shell
package server
import (
"context"
@@ -14,25 +14,19 @@ import (
const getentTimeout = 5 * time.Second
// GetShellFromGetent gets a user's login shell via getent by UID.
// getShellFromGetent gets a user's login shell via getent by UID.
// This is needed even with CGO because getShellFromPasswd reads /etc/passwd
// directly and won't find NSS-provided users there.
func GetShellFromGetent(userID string) string {
_, shell, err := runGetentPasswd(userID)
func getShellFromGetent(userID string) string {
_, shell, err := runGetent(userID)
if err != nil {
return ""
}
return shell
}
// GetUserFromGetent returns the resolved group from either a uid or username
func GetUserFromGetent(user string) (*user.User, error) {
u, _, err := runGetentPasswd(user)
return u, err
}
// runGetentPasswd executes `getent passwd <query>` and returns the user and login shell.
func runGetentPasswd(query string) (*user.User, string, error) {
// runGetent executes `getent passwd <query>` and returns the user and login shell.
func runGetent(query string) (*user.User, string, error) {
if !validateGetentInput(query) {
return nil, "", fmt.Errorf("invalid getent input: %q", query)
}
@@ -48,23 +42,6 @@ func runGetentPasswd(query string) (*user.User, string, error) {
return parseGetentPasswd(string(out))
}
// runGetentGroup executes `getent group <query>` and returns the group
func runGetentGroup(query string) (*user.Group, error) {
if !validateGetentInput(query) {
return nil, fmt.Errorf("invalid getent input: %q", query)
}
ctx, cancel := context.WithTimeout(context.Background(), getentTimeout)
defer cancel()
out, err := exec.CommandContext(ctx, "getent", "group", query).Output()
if err != nil {
return nil, fmt.Errorf("getent passwd%s: %w", query, err)
}
return parseGetentGroup(string(out))
}
// parseGetentPasswd parses getent passwd output: "name:x:uid:gid:gecos:home:shell"
func parseGetentPasswd(output string) (*user.User, string, error) {
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
@@ -90,20 +67,6 @@ func parseGetentPasswd(output string) (*user.User, string, error) {
}, shell, nil
}
// parseGetentGroup parses getent group output: "group:x:gid:user"
func parseGetentGroup(output string) (*user.Group, error) {
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
if len(fields) < 4 {
return nil, fmt.Errorf("unexpected getent output (need 4+ fields): %q", output)
}
if fields[0] == "" || fields[2] == "" {
return nil, fmt.Errorf("missing required fields in getent output: %q", output)
}
return &user.Group{Gid: fields[2], Name: fields[0]}, nil
}
// validateGetentInput checks that the input is safe to pass to getent or id.
// Allows POSIX usernames, numeric UIDs, and common NSS extensions
// (@ for Kerberos, $ for Samba, + for NIS compat).

View File

@@ -1,6 +1,6 @@
//go:build !windows
package shell
package server
import (
"os/exec"
@@ -195,7 +195,7 @@ func TestRunGetent_RootUser(t *testing.T) {
t.Skip("getent not available on this system")
}
u, shell, err := runGetentPasswd("root")
u, shell, err := runGetent("root")
require.NoError(t, err)
assert.Equal(t, "root", u.Username)
assert.Equal(t, "0", u.Uid)
@@ -208,7 +208,7 @@ func TestRunGetent_ByUID(t *testing.T) {
t.Skip("getent not available on this system")
}
u, _, err := runGetentPasswd("0")
u, _, err := runGetent("0")
require.NoError(t, err)
assert.Equal(t, "root", u.Username)
assert.Equal(t, "0", u.Uid)
@@ -219,15 +219,15 @@ func TestRunGetent_NonexistentUser(t *testing.T) {
t.Skip("getent not available on this system")
}
_, _, err := runGetentPasswd("nonexistent_user_xyzzy_12345")
_, _, err := runGetent("nonexistent_user_xyzzy_12345")
assert.Error(t, err)
}
func TestRunGetent_InvalidInput(t *testing.T) {
_, _, err := runGetentPasswd("")
_, _, err := runGetent("")
assert.Error(t, err)
_, _, err = runGetentPasswd("user\x00name")
_, _, err = runGetent("user\x00name")
assert.Error(t, err)
}
@@ -236,7 +236,7 @@ func TestRunGetent_NotAvailable(t *testing.T) {
t.Skip("getent is available, can't test missing case")
}
_, _, err := runGetentPasswd("root")
_, _, err := runGetent("root")
assert.Error(t, err, "should fail when getent is not installed")
}
@@ -283,7 +283,7 @@ func TestGetentResultsMatchStdlib(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
getentUser, _, err := runGetentPasswd(current.Username)
getentUser, _, err := runGetent(current.Username)
require.NoError(t, err)
assert.Equal(t, current.Username, getentUser.Username, "username should match")
@@ -300,7 +300,7 @@ func TestGetentResultsMatchStdlib_ByUID(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
getentUser, _, err := runGetentPasswd(current.Uid)
getentUser, _, err := runGetent(current.Uid)
require.NoError(t, err)
assert.Equal(t, current.Username, getentUser.Username, "username should match when looked up by UID")
@@ -356,7 +356,7 @@ func TestGetShellFromPasswd_CurrentUser(t *testing.T) {
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
if _, err := exec.LookPath("getent"); err == nil {
_, getentShell, getentErr := runGetentPasswd(current.Uid)
_, getentShell, getentErr := runGetent(current.Uid)
if getentErr == nil && getentShell != "" {
assert.Equal(t, getentShell, shell, "shell from /etc/passwd should match getent")
}
@@ -400,7 +400,7 @@ func TestGetShellFromPasswd_MatchesGetentForKnownUsers(t *testing.T) {
continue
}
_, getentShell, err := runGetentPasswd(uid)
_, getentShell, err := runGetent(uid)
if err != nil {
continue
}

View File

@@ -1,26 +1,26 @@
//go:build windows
package shell
package server
import "os/user"
// lookupWithGetent on Windows just delegates to os/user.Lookup.
// Windows does not use NSS/getent; its user lookup works without CGO.
func LookupWithGetent(username string) (*user.User, error) {
func lookupWithGetent(username string) (*user.User, error) {
return user.Lookup(username)
}
// currentUserWithGetent on Windows just delegates to os/user.Current.
func CurrentUserWithGetent() (*user.User, error) {
func currentUserWithGetent() (*user.User, error) {
return user.Current()
}
// getShellFromGetent is a no-op on Windows; shell resolution uses PowerShell detection.
func GetShellFromGetent(_ string) string {
func getShellFromGetent(_ string) string {
return ""
}
// groupIdsWithFallback on Windows just delegates to u.GroupIds().
func GroupIdsWithFallback(u *user.User) ([]string, error) {
func groupIdsWithFallback(u *user.User) ([]string, error) {
return u.GroupIds()
}

View File

@@ -1,14 +1,17 @@
package shell
package server
import (
"bufio"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
@@ -21,7 +24,7 @@ const (
// getUserShell returns the appropriate shell for the given user ID
// Handles all platform-specific logic and fallbacks consistently
func GetUserShell(userID string) string {
func getUserShell(userID string) string {
switch runtime.GOOS {
case "windows":
return getWindowsUserShell()
@@ -53,7 +56,7 @@ func getUnixUserShell(userID string) string {
return shell
}
if shell := GetShellFromGetent(userID); shell != "" {
if shell := getShellFromGetent(userID); shell != "" {
return shell
}
@@ -98,8 +101,8 @@ func getShellFromPasswd(userID string) string {
return ""
}
// PrepareUserEnv prepares environment variables for user execution
func PrepareUserEnv(user *user.User, shell string) []string {
// prepareUserEnv prepares environment variables for user execution
func prepareUserEnv(user *user.User, shell string) []string {
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
if runtime.GOOS == "windows" {
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
@@ -116,7 +119,7 @@ func PrepareUserEnv(user *user.User, shell string) []string {
// acceptEnv checks if environment variable from SSH client should be accepted
// This is a whitelist of variables that SSH clients can send to the server
func AcceptEnv(envVar string) bool {
func acceptEnv(envVar string) bool {
varName := envVar
if idx := strings.Index(envVar, "="); idx != -1 {
varName = envVar[:idx]
@@ -153,3 +156,29 @@ func AcceptEnv(envVar string) bool {
return false
}
// prepareSSHEnv prepares SSH protocol-specific environment variables
// These variables provide information about the SSH connection itself
func prepareSSHEnv(session ssh.Session) []string {
remoteAddr := session.RemoteAddr()
localAddr := session.LocalAddr()
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
if err != nil {
remoteHost = remoteAddr.String()
remotePort = "0"
}
localHost, localPort, err := net.SplitHostPort(localAddr.String())
if err != nil {
localHost = localAddr.String()
localPort = strconv.Itoa(InternalSSHPort)
}
return []string{
// SSH_CLIENT format: "client_ip client_port server_port"
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
}
}

View File

@@ -3,15 +3,11 @@ package server
import (
"errors"
"fmt"
"net"
"os"
"os/user"
"runtime"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
"github.com/netbirdio/netbird/client/internal/shell"
log "github.com/sirupsen/logrus"
)
@@ -27,8 +23,8 @@ func isPlatformUnix() bool {
// Dependency injection variables for testing - allows mocking dynamic runtime checks
var (
getCurrentUser = shell.CurrentUserWithGetent
lookupUser = shell.LookupWithGetent
getCurrentUser = currentUserWithGetent
lookupUser = lookupWithGetent
getCurrentOS = func() string { return runtime.GOOS }
getIsProcessPrivileged = isCurrentProcessPrivileged
@@ -413,29 +409,3 @@ func isWindowsElevated() bool {
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
return false
}
// prepareSSHEnv prepares SSH protocol-specific environment variables
// These variables provide information about the SSH connection itself
func prepareSSHEnv(session ssh.Session) []string {
remoteAddr := session.RemoteAddr()
localAddr := session.LocalAddr()
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
if err != nil {
remoteHost = remoteAddr.String()
remotePort = "0"
}
localHost, localPort, err := net.SplitHostPort(localAddr.String())
if err != nil {
localHost = localAddr.String()
localPort = strconv.Itoa(InternalSSHPort)
}
return []string{
// SSH_CLIENT format: "client_ip client_port server_port"
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
}
}

View File

@@ -15,7 +15,6 @@ import (
"strconv"
"github.com/gliderlabs/ssh"
"github.com/netbirdio/netbird/client/internal/shell"
log "github.com/sirupsen/logrus"
)
@@ -161,7 +160,7 @@ func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []u
// getSupplementaryGroups retrieves supplementary group IDs for a user.
// Uses id/getent fallback for NSS users in CGO_ENABLED=0 builds.
func (s *Server) getSupplementaryGroups(u *user.User) ([]uint32, error) {
groupIDStrings, err := shell.GroupIdsWithFallback(u)
groupIDStrings, err := groupIdsWithFallback(u)
if err != nil {
return nil, fmt.Errorf("get group IDs for user %s: %w", u.Username, err)
}
@@ -197,7 +196,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
GID: gid,
Groups: groups,
WorkingDir: localUser.HomeDir,
Shell: shell.GetUserShell(localUser.Uid),
Shell: getUserShell(localUser.Uid),
Command: session.RawCommand(),
PTY: hasPty,
}
@@ -229,7 +228,7 @@ func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq s
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
shell := shell.GetUserShell(localUser.Uid)
shell := getUserShell(localUser.Uid)
args := s.getShellCommandArgs(shell, session.RawCommand())
cmd := s.createShellCommand(session.Context(), shell, args)
@@ -246,12 +245,12 @@ func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh
termType = "xterm-256color"
}
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
env = append(env, fmt.Sprintf("TERM=%s", termType))
for _, v := range session.Environ() {
if shell.AcceptEnv(v) {
if acceptEnv(v) {
env = append(env, v)
}
}

View File

@@ -13,8 +13,6 @@ import (
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/shell"
)
// validateUsername validates Windows usernames according to SAM Account Name rules
@@ -106,7 +104,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
username, domain := s.parseUsername(localUser.Username)
sh := shell.GetUserShell(localUser.Uid)
shell := getUserShell(localUser.Uid)
rawCmd := session.RawCommand()
var command string
@@ -118,7 +116,7 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
Username: username,
Domain: domain,
WorkingDir: localUser.HomeDir,
Shell: sh,
Shell: shell,
Command: command,
}

View File

@@ -19,46 +19,6 @@ readonly MSG_SEPARATOR="=========================================="
# Utility Functions
############################################
check_docker_sock_perms() {
local sock="${DOCKER_HOST:-unix:///var/run/docker.sock}"
sock="${sock#unix://}"
if [[ ! -S "$sock" ]]; then
return 0
fi
if [[ ! -r "$sock" ]] || [[ ! -w "$sock" ]]; then
local group
if [[ "${OSTYPE}" == "darwin"* ]]; then
group="$(stat -f '%Sg' "$sock")"
else
group="$(stat -c '%G' "$sock")"
fi
echo "Cannot access Docker socket: $sock" > /dev/stderr
echo "" > /dev/stderr
echo "Socket permissions:" > /dev/stderr
ls -l "$sock" > /dev/stderr
echo "" > /dev/stderr
if [[ "$group" == "docker" ]]; then
echo "Your user may need to be added to the '$group' group:" > /dev/stderr
echo " sudo usermod -aG $group \"$USER\"" > /dev/stderr
echo "Then log out and back in, or run this for the current shell:" > /dev/stderr
echo " newgrp $group" > /dev/stderr
echo "Note: newgrp is temporary; usermod is the permanent group change." > /dev/stderr
else
echo "The Docker socket is owned by the '$group' group, which is not the standard 'docker' group." > /dev/stderr
echo "For safety, this script will not suggest adding your user to '$group'." > /dev/stderr
echo "Instead, either run this script with appropriate privileges (for example, via sudo) or follow Docker's post-install steps to configure access via the 'docker' group:" > /dev/stderr
echo " https://docs.docker.com/engine/install/linux-postinstall/" > /dev/stderr
fi
exit 1
fi
return 0
}
check_docker_compose() {
if command -v docker-compose &> /dev/null
then
@@ -621,15 +581,12 @@ start_services_and_show_instructions() {
}
init_environment() {
# Check if docker compose is installed using check_docker_compose function
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
check_docker_sock_perms
initialize_default_values
configure_domain
configure_reverse_proxy
check_jq
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
check_existing_installation
generate_configuration_files

View File

@@ -120,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.InstanceManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.AuthMiddleware())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -153,6 +153,20 @@ func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
})
}
func (s *BaseServer) AuthMiddleware() mux.MiddlewareFunc {
return Create(s, func() mux.MiddlewareFunc {
m := middleware.NewAuthMiddleware(
s.AuthManager(),
s.AccountManager().GetAccountIDFromUserAuth,
s.AccountManager().SyncUserJWTGroups,
s.AccountManager().GetUserFromUserAuth,
s.RateLimiter(),
s.Metrics().GetMeter(),
)
return m.Handler
})
}
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.Config.ReverseProxy.TrustedPeers

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
@@ -151,6 +152,16 @@ func (s *BaseServer) IdpManager() idp.Manager {
})
}
func (s *BaseServer) InstanceManager() instance.Manager {
return Create(s, func() instance.Manager {
m, err := instance.NewManager(context.Background(), s.Store(), s.IdpManager())
if err != nil {
log.Fatalf("failed to create instance manager: %v", err)
}
return m
})
}
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
@@ -229,6 +240,3 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
})
}
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
return false
}

View File

@@ -12,8 +12,6 @@ import (
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
@@ -147,9 +145,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig
response.NetbirdConfig = toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
response.NetworkMap.PeerConfig = response.PeerConfig
@@ -363,7 +359,6 @@ func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSource
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {

View File

@@ -978,7 +978,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
Mode: m.Mode,
ListenPort: m.ListenPort,
AccessRestrictions: m.AccessRestrictions,
Private: m.Private,
}
}

View File

@@ -1,88 +0,0 @@
package grpc
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
)
// authTokenField is the only per-proxy field that shallowCloneMapping must NOT
// copy from the source, since callers assign it individually after cloning.
const authTokenField = "AuthToken"
// TestShallowCloneMapping_ClonesAllFields populates every exported field of
// ProxyMapping with a non-zero value and verifies the clone carries each one
// (except AuthToken). It uses reflection so adding a new field to ProxyMapping
// without updating shallowCloneMapping fails this test.
func TestShallowCloneMapping_ClonesAllFields(t *testing.T) {
src := &proto.ProxyMapping{}
populated := populateExportedFields(t, reflect.ValueOf(src).Elem())
require.NotEmpty(t, populated, "ProxyMapping should expose fields to populate")
clone := shallowCloneMapping(src)
require.NotNil(t, clone, "clone must not be nil")
srcVal := reflect.ValueOf(src).Elem()
cloneVal := reflect.ValueOf(clone).Elem()
for _, name := range populated {
srcField := srcVal.FieldByName(name).Interface()
cloneField := cloneVal.FieldByName(name).Interface()
if name == authTokenField {
assert.Zero(t, cloneField, "AuthToken must not be cloned; it is set per proxy after cloning")
continue
}
assert.Equal(t, srcField, cloneField, "field %s must be carried over by shallowCloneMapping", name)
}
}
// populateExportedFields sets a non-zero value on every settable exported field
// of the struct and returns their names.
func populateExportedFields(t *testing.T, v reflect.Value) []string {
t.Helper()
var names []string
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
structField := typ.Field(i)
if structField.PkgPath != "" || !field.CanSet() {
continue
}
setNonZero(t, field, structField.Name)
names = append(names, structField.Name)
}
return names
}
// setNonZero assigns a deterministic non-zero value based on the field kind.
func setNonZero(t *testing.T, field reflect.Value, name string) {
t.Helper()
switch field.Kind() {
case reflect.String:
field.SetString("non-zero-" + name)
case reflect.Bool:
field.SetBool(true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.SetInt(7)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.SetUint(7)
case reflect.Ptr:
field.Set(reflect.New(field.Type().Elem()))
case reflect.Slice:
field.Set(reflect.MakeSlice(field.Type(), 1, 1))
case reflect.Map:
field.Set(reflect.MakeMapWithSize(field.Type(), 0))
default:
t.Fatalf("unhandled field kind %s for field %s; extend setNonZero", field.Kind(), name)
}
}

View File

@@ -12,7 +12,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
@@ -41,8 +40,6 @@ type TimeBasedAuthSecretsManager struct {
turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator
updateManager network_map.PeersUpdateManager
settingsManager settings.Manager
groupsManager groups.Manager
turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{}
wgKey wgtypes.Key
@@ -62,8 +59,6 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
relayCfg: relayCfg,
turnCancelMap: make(map[string]chan struct{}),
relayCancelMap: make(map[string]chan struct{}),
settingsManager: settingsManager,
groupsManager: groupsManager,
wgKey: key,
}
@@ -239,8 +234,6 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
}
}
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: update,
@@ -266,26 +259,9 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
},
}
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeControlConfig,
})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
}
peerGroups, err := m.groupsManager.GetPeerGroupIDs(ctx, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
}
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
update.NetbirdConfig = extendedConfig
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/gorilla/mux"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
@@ -20,7 +19,6 @@ import (
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/zones"
@@ -34,7 +32,6 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroups "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
@@ -49,7 +46,6 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/handlers/users"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
@@ -59,7 +55,7 @@ import (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, instanceManager nbinstance.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, authMiddleware mux.MiddlewareFunc) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -80,32 +76,11 @@ func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager accou
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
if rateLimiter == nil {
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
rateLimiter = middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
}
authMiddleware := middleware.NewAuthMiddleware(
authManager,
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimiter,
appMetrics.GetMeter(),
isValidChildAccount,
)
corsMiddleware := cors.AllowAll()
metricsMiddleware := appMetrics.HTTPMiddleware()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
if err != nil {
return nil, fmt.Errorf("failed to create instance manager: %w", err)
}
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware)
accounts.AddEndpoints(accountManager, settingsManager, router)
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
@@ -25,7 +26,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
// jwtTokenCtxKey carries the parsed JWT token.
type jwtTokenCtxKey struct{}
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
@@ -35,7 +37,6 @@ type AuthMiddleware struct {
syncUserJWTGroups SyncUserJWTGroupsFunc
rateLimiter *APIRateLimiter
patUsageTracker *PATUsageTracker
isValidChildAccount IsValidChildAccountFunc
}
// NewAuthMiddleware instance constructor
@@ -46,7 +47,6 @@ func NewAuthMiddleware(
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiter *APIRateLimiter,
meter metric.Meter,
isValidChildAccount IsValidChildAccountFunc,
) *AuthMiddleware {
var patUsageTracker *PATUsageTracker
if meter != nil {
@@ -64,12 +64,18 @@ func NewAuthMiddleware(
getUserFromUserAuth: getUserFromUserAuth,
rateLimiter: rateLimiter,
patUsageTracker: patUsageTracker,
isValidChildAccount: isValidChildAccount,
}
}
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
// Handler composes the full authentication chain by wrapping the given
// handler with ValidationHandler followed by AccountAccessHandler.
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return m.ValidationHandler(m.AccountAccessHandler(h))
}
// ValidationHandler authenticates the caller via JWT or PAT and stores the
// resulting UserAuth in the request context. It performs no account-level work.
func (m *AuthMiddleware) ValidationHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
@@ -86,14 +92,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
switch authType {
case "bearer":
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
if err := m.validateJWT(r, authHeader); err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
if err := m.checkPATFromRequest(r, authHeader); err != nil {
if err := m.validatePAT(r, authHeader); err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
// Check if it's a status error, otherwise default to Unauthorized
if _, ok := status.FromError(err); !ok {
@@ -110,66 +116,55 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
})
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromJWTRequest(authHeaderParts)
// AccountAccessHandler runs post-validation access checks for JWT-authenticated
// requests. PAT requests pass through unchanged.
func (m *AuthMiddleware) AccountAccessHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}
// If an error occurs, call the error handler and return an error
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
if userAuth.IsPAT {
h.ServeHTTP(w, r)
return
}
validatedToken, _ := r.Context().Value(jwtTokenCtxKey{}).(*jwt.Token)
if err := m.applyAccountAccess(r, userAuth, validatedToken); err != nil {
log.WithContext(r.Context()).Errorf("Error applying JWT account access: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
})
}
func (m *AuthMiddleware) validateJWT(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromJWTRequest(authHeaderParts)
if err != nil {
return fmt.Errorf("error extracting token: %w", err)
}
ctx := r.Context()
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(r.Context(), token)
if err != nil {
return err
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = true
}
}
// Email is now extracted in ToUserAuth (from claims or userinfo endpoint)
// Available as userAuth.Email
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil {
return err
}
if userAuth.AccountId != accountId {
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId
}
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
if err != nil {
return err
}
err = m.syncUserJWTGroups(ctx, userAuth)
if err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
}
_, err = m.getUserFromUserAuth(ctx, userAuth)
if err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
return err
}
// propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
*r = *r.WithContext(context.WithValue(r.Context(), jwtTokenCtxKey{}, validatedToken))
return nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
func (m *AuthMiddleware) validatePAT(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromPATRequest(authHeaderParts)
if err != nil {
return fmt.Errorf("error extracting token: %w", err)
@@ -192,8 +187,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
return fmt.Errorf("token expired")
}
err = m.authManager.MarkPATUsed(ctx, pat.ID)
if err != nil {
if err := m.authManager.MarkPATUsed(ctx, pat.ID); err != nil {
return err
}
@@ -205,11 +199,40 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
IsPAT: true,
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = true
}
// propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
return nil
}
// applyAccountAccess executes account-level checks for an authenticated JWT
// user: ensures the account exists, verifies access via JWT groups, syncs
// groups, and fetches the user record.
func (m *AuthMiddleware) applyAccountAccess(r *http.Request, userAuth auth.UserAuth, validatedToken *jwt.Token) error {
ctx := r.Context()
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil {
return err
}
if userAuth.AccountId != accountId {
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId
}
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
if err != nil {
return err
}
if err := m.syncUserJWTGroups(ctx, userAuth); err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
}
if _, err := m.getUserFromUserAuth(ctx, userAuth); err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
return err
}
// propagates ctx change to upstream middleware

View File

@@ -211,7 +211,6 @@ func TestAuthMiddleware_Handler(t *testing.T) {
},
disabledLimiter,
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -271,7 +270,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -324,7 +322,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -368,7 +365,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -413,7 +409,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -478,7 +473,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -538,7 +532,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -594,7 +587,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -695,7 +687,6 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
},
disabledLimiter,
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
for _, tc := range tt {

View File

@@ -40,6 +40,7 @@ import (
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
http2 "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
@@ -136,8 +137,12 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
rateLimiter := middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
authMiddleware := middleware.NewAuthMiddleware(authManagerMock, am.GetAccountIDFromUserAuth, am.SyncUserJWTGroups, am.GetUserFromUserAuth, rateLimiter, metrics.GetMeter())
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, authMiddleware.Handler)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -266,8 +271,12 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
rateLimiter := middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
authMiddleware := middleware.NewAuthMiddleware(authManagerMock, am.GetAccountIDFromUserAuth, am.SyncUserJWTGroups, am.GetUserFromUserAuth, rateLimiter, metrics.GetMeter())
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, authMiddleware.Handler)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -1216,7 +1216,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("NetworkResources").
Preload("Onboarding").
Preload("Services.Targets").
Preload("Domains").
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
@@ -1303,7 +1302,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
}
var wg sync.WaitGroup
errChan := make(chan error, 16)
errChan := make(chan error, 12)
wg.Add(1)
go func() {
@@ -1404,17 +1403,6 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
account.Services = services
}()
wg.Add(1)
go func() {
defer wg.Done()
domains, err := s.ListCustomDomains(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Domains = domains
}()
wg.Add(1)
go func() {
defer wg.Done()

View File

@@ -4,8 +4,6 @@ import (
"context"
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
@@ -23,63 +21,6 @@ import (
"github.com/netbirdio/netbird/route"
)
// TestGetAccount_LoadsCustomDomains verifies GetAccount populates account.Domains.
// SynthesizePrivateServiceZones depends on this relation to anchor a custom-domain
// private service's DNS zone; without the preload the relation is empty and the
// service is silently skipped, so a custom domain never resolves on clients.
func TestGetAccount_LoadsCustomDomains(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
defer cleanup()
assertGetAccountLoadsCustomDomains(t, store)
}
func TestPostgresql_GetAccount_LoadsCustomDomains(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
assertGetAccountLoadsCustomDomains(t, store)
}
// assertGetAccountLoadsCustomDomains exercises both the gorm and pgx GetAccount
// paths: it persists two custom domains and asserts the relation comes back
// populated, which SynthesizePrivateServiceZones relies on.
func assertGetAccountLoadsCustomDomains(t *testing.T, store Store) {
t.Helper()
ctx := context.Background()
accountID := "acct-custom-domains"
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
_, err := store.CreateCustomDomain(ctx, accountID, "example.com", "eu.proxy.netbird.io", true)
require.NoError(t, err, "creating the first custom domain must succeed")
_, err = store.CreateCustomDomain(ctx, accountID, "apps.acme.io", "us.proxy.netbird.io", false)
require.NoError(t, err, "creating the second custom domain must succeed")
account, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
require.Len(t, account.Domains, 2, "GetAccount must preload the account's custom domains")
byDomain := map[string]string{}
for _, d := range account.Domains {
require.NotNil(t, d)
byDomain[d.Domain] = d.TargetCluster
}
assert.Equal(t, "eu.proxy.netbird.io", byDomain["example.com"], "custom domain must carry its target cluster")
assert.Equal(t, "us.proxy.netbird.io", byDomain["apps.acme.io"], "custom domain must carry its target cluster")
}
// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads
// all fields and nested objects from the database, including deeply nested structures.
func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) {

View File

@@ -273,7 +273,7 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
}
peerGroups := a.GetPeerGroups(peerID)
zonesByApex := map[string]*nbdns.CustomZone{}
zonesByCluster := map[string]*nbdns.CustomZone{}
for _, svc := range a.Services {
if svc == nil || !svc.Enabled || !svc.Private {
@@ -290,24 +290,19 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
continue
}
serviceDomainZone := a.privateServiceDomainZone(svc)
if serviceDomainZone == "" {
continue
}
zone, exists := zonesByApex[serviceDomainZone]
zone, exists := zonesByCluster[svc.ProxyCluster]
if !exists {
// NonAuthoritative makes this a match-only zone: queries for
// names without an explicit record fall through to the
// upstream resolver instead of returning NXDOMAIN. Without
// it, adding a single private service would black-hole every
// other name under the zone apex.
// other name under the cluster apex.
zone = &nbdns.CustomZone{
Domain: dns.Fqdn(serviceDomainZone),
Domain: dns.Fqdn(svc.ProxyCluster),
Records: []nbdns.SimpleRecord{},
NonAuthoritative: true,
}
zonesByApex[serviceDomainZone] = zone
zonesByCluster[svc.ProxyCluster] = zone
}
emitted := 0
@@ -345,8 +340,8 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
}
}
out := make([]nbdns.CustomZone, 0, len(zonesByApex))
for _, zone := range zonesByApex {
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
for _, zone := range zonesByCluster {
if len(zone.Records) == 0 {
continue
}
@@ -362,33 +357,6 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
return out
}
// privateServiceDomainZone returns the DNS zone name for the given private service domain by
// looking at the proxy cluster domain then the custom domains.
func (a *Account) privateServiceDomainZone(svc *service.Service) string {
if domainFromSuffix(svc.Domain, svc.ProxyCluster) {
return svc.ProxyCluster
}
// Longest matching custom domain wins
zoneName := ""
for _, d := range a.Domains {
if d == nil || d.TargetCluster != svc.ProxyCluster {
continue
}
if domainFromSuffix(svc.Domain, d.Domain) && len(d.Domain) > len(zoneName) {
zoneName = d.Domain
}
}
return zoneName
}
func domainFromSuffix(domain, suffix string) bool {
if suffix == "" {
return false
}
return domain == suffix || strings.HasSuffix(domain, "."+suffix)
}
// peerInDistributionGroups reports whether any of the peer's groups
// matches the service's bearer-auth distribution_groups.
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {

View File

@@ -11,7 +11,6 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
@@ -235,113 +234,6 @@ func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testi
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
}
func TestSynthesizePrivateServiceZones_CustomDomain_ZoneApexIsRegisteredDomain(t *testing.T) {
account := privateZoneTestAccount(t)
// A custom-domain service: Domain is the custom FQDN, ProxyCluster
// is the cluster serving it, and account.Domains holds the registered
// custom domain. The synth zone apex must be the registered domain,
// not the cluster, or the client's match-only zone never intercepts
// the query.
account.Services[0].Domain = "app.example.com"
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "custom-domain service must still produce one zone")
zone := zones[0]
assert.Equal(t, "example.com.", zone.Domain, "zone apex must be the registered custom domain, not the cluster or the service FQDN")
assert.True(t, zone.NonAuthoritative, "synth zone must remain match-only")
require.Len(t, zone.Records, 1, "custom-domain service yields one A record")
rec := zone.Records[0]
assert.Equal(t, "app.example.com.", rec.Name, "record name is the custom service FQDN")
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
}
func TestSynthesizePrivateServiceZones_CustomAndFreeDomain_SeparateZones(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
account.Services = append(account.Services, &service.Service{
ID: "svc-2",
AccountID: "acct-1",
Name: "custom",
Domain: "app.example.com",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
})
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 2, "a free-domain and a custom-domain service must not collapse into one zone")
free, ok := findCustomZone(zones, "eu.proxy.netbird.io")
require.True(t, ok, "free-domain service keeps the shared cluster-apex zone")
require.Len(t, free.Records, 1, "cluster zone carries only the free-domain record")
assert.Equal(t, "myapp.eu.proxy.netbird.io.", free.Records[0].Name, "cluster zone record is the free-domain FQDN")
custom, ok := findCustomZone(zones, "example.com")
require.True(t, ok, "custom-domain service gets its own zone at the registered custom domain apex")
require.Len(t, custom.Records, 1, "custom zone carries only the custom-domain record")
assert.Equal(t, "app.example.com.", custom.Records[0].Name, "custom zone record is the custom-domain FQDN")
}
func TestSynthesizePrivateServiceZones_TwoServicesSameCustomDomain_OneZone(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
account.Services[0].Domain = "a.example.com"
account.Services = append(account.Services, &service.Service{
ID: "svc-2",
AccountID: "acct-1",
Name: "bapp",
Domain: "b.example.com",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
})
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "two services under the same registered custom domain must share one zone")
assert.Equal(t, "example.com.", zones[0].Domain, "shared zone apex is the registered custom domain")
require.Len(t, zones[0].Records, 2, "both services surface as records in the shared custom-domain zone")
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
assert.ElementsMatch(t, []string{"a.example.com.", "b.example.com."}, names, "both custom-domain service FQDNs must surface")
}
func TestSynthesizePrivateServiceZones_CustomDomainNotRegistered_NoZone(t *testing.T) {
account := privateZoneTestAccount(t)
// Service domain is outside the cluster and no account.Domains entry
// covers it: there is no apex that would intercept the query, so the
// service must be skipped rather than emit an unmatchable record.
account.Services[0].Domain = "app.example.com"
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "a custom-domain service with no registered domain apex must not produce a zone")
}
func TestSynthesizePrivateServiceZones_CustomDomainClusterMismatch_NoZone(t *testing.T) {
account := privateZoneTestAccount(t)
// The registered custom domain matches the service FQDN by suffix but
// targets a different cluster than the service's ProxyCluster. It must
// be ignored, leaving no apex to intercept the query — otherwise the
// zone would point at this cluster's proxy peers under a domain owned
// by a different cluster.
account.Services[0].Domain = "app.example.com"
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "us.proxy.netbird.io", Validated: true},
}
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "a custom domain targeting a different cluster must not anchor the service zone")
}
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
account := privateZoneTestAccount(t)
account.Services = append(account.Services, &service.Service{
@@ -362,72 +254,3 @@ func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
}
func TestSynthesizePrivateServiceZones_MixedClusterCustomAndPublic(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
privateService := func(id, domain string) *service.Service {
return &service.Service{
ID: id,
AccountID: "acct-1",
Name: id,
Domain: domain,
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
}
}
publicService := func(id, domain string) *service.Service {
s := privateService(id, domain)
s.Private = false
return s
}
account.Services = []*service.Service{
// 3 private services under the cluster suffix.
privateService("cluster-1", "cluster1.eu.proxy.netbird.io"),
privateService("cluster-2", "cluster2.eu.proxy.netbird.io"),
privateService("cluster-3", "cluster3.eu.proxy.netbird.io"),
// 4 private services under the custom domain suffix.
privateService("custom-1", "custom1.example.com"),
privateService("custom-2", "custom2.example.com"),
privateService("custom-3", "custom3.example.com"),
privateService("custom-4", "custom4.example.com"),
// 2 public services, one per suffix, must not surface.
publicService("public-cluster", "public.eu.proxy.netbird.io"),
publicService("public-custom", "public.example.com"),
}
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 2, "one zone per apex: the cluster apex and the custom domain apex")
cluster, ok := findCustomZone(zones, "eu.proxy.netbird.io")
require.True(t, ok, "cluster-suffix services collapse into the cluster-apex zone")
clusterNames := recordNames(cluster)
assert.ElementsMatch(t,
[]string{"cluster1.eu.proxy.netbird.io.", "cluster2.eu.proxy.netbird.io.", "cluster3.eu.proxy.netbird.io."},
clusterNames,
"only the 3 private cluster services surface in the cluster zone (public one excluded)")
custom, ok := findCustomZone(zones, "example.com")
require.True(t, ok, "custom-suffix services collapse into the custom-domain-apex zone")
customNames := recordNames(custom)
assert.ElementsMatch(t,
[]string{"custom1.example.com.", "custom2.example.com.", "custom3.example.com.", "custom4.example.com."},
customNames,
"only the 4 private custom services surface in the custom zone (public one excluded)")
}
// recordNames returns the record names of a zone for order-independent assertions.
func recordNames(zone nbdns.CustomZone) []string {
names := make([]string, 0, len(zone.Records))
for _, r := range zone.Records {
names = append(names, r.Name)
}
return names
}

View File

@@ -417,30 +417,15 @@ if type uname >/dev/null 2>&1; then
# Check the availability of a compatible package manager
if check_use_bin_variable; then
PACKAGE_MANAGER="bin"
elif [ -e /run/ostree-booted ]; then
if [ -x "$(command -v rpm-ostree)" ]; then
PACKAGE_MANAGER="rpm-ostree"
echo "The installation will be performed using rpm-ostree package manager"
elif [ -x "$(command -v bootc)" ]; then
echo "Detected bootc system without rpm-ostree." >&2
echo "NetBird cannot be installed via package manager on this system." >&2
echo "Options:" >&2
echo " 1. Install via Distrobox (instructions in the installation docs)" >&2
echo " 2. Rebuild your base image with rpm-ostree included" >&2
echo " 3. Bake NetBird into your Containerfile" >&2
exit 1
else
echo "Detected ostree-booted system without rpm-ostree or bootc." >&2
echo "NetBird cannot be installed automatically on this atomic system." >&2
echo "Please install NetBird by rebuilding your base image or use a supported package manager." >&2
exit 1
fi
elif [ -x "$(command -v apt-get)" ]; then
PACKAGE_MANAGER="apt"
echo "The installation will be performed using apt package manager"
elif [ -x "$(command -v dnf)" ]; then
PACKAGE_MANAGER="dnf"
echo "The installation will be performed using dnf package manager"
elif [ -x "$(command -v rpm-ostree)" ]; then
PACKAGE_MANAGER="rpm-ostree"
echo "The installation will be performed using rpm-ostree package manager"
elif [ -x "$(command -v yum)" ]; then
PACKAGE_MANAGER="yum"
echo "The installation will be performed using yum package manager"