Compare commits

...

22 Commits

Author SHA1 Message Date
Zoltan Papp
08b782d6ba [client] Fix update download url (#5023) 2026-01-03 20:05:38 +03:00
Maycon Santos
80a312cc9c [client] add verbose flag for free ad tests (#5021)
add verbose flag for free ad tests
2026-01-03 11:32:41 +01:00
Zoltan Papp
9ba067391f [client] Fix semaphore slot leaks (#5018)
- Remove WaitGroup, make SemaphoreGroup a pure semaphore
- Make Add() return error instead of silently failing on context cancel
- Remove context parameter from Done() to prevent slot leaks
- Fix missing Done() call in conn.go error path
2026-01-03 09:10:02 +01:00
Pascal Fischer
7ac65bf1ad [management] Fix/delete groups without lock (#5012) 2025-12-31 11:53:20 +01:00
Zoltan Papp
2e9c316852 Fix UI stuck in "Connecting" state when daemon reports "Connected" status. (#5014)
The UI can get stuck showing "Connecting" status even after the daemon successfully connects and reports "Connected" status. This occurs because the condition to update the UI to "Connected" state checks the wrong flag.
2025-12-31 11:50:43 +01:00
shuuri-labs
96cdd56902 Feat/add support for forcing device auth flow on ios (#4944)
* updates to client file writing

* numerous

* minor

* - Align OnLoginSuccess behavior with Android (only call on nil error)
- Remove verbose debug logging from WaitToken in device_flow.go
- Improve TUN FD=0 fallback comments and warning messages
- Document why config save after login differs from Android

* Add nolint directive for staticcheck SA1029 in login.go

* Fix CodeRabbit review issues for iOS/tvOS SDK

- Remove goroutine from OnLoginSuccess callback, invoke synchronously
- Stop treating PermissionDenied as success, propagate as permanent error
- Replace context.TODO() with bounded timeout context (30s) in RequestAuthInfo
- Handle DirectUpdateOrCreateConfig errors in IsLoginRequired and LoginForMobile
- Add permission enforcement to DirectUpdateOrCreateConfig for existing configs
- Fix variable shadowing in device_ios.go where err was masked by := in else block

* Address additional CodeRabbit review issues for iOS/tvOS SDK

- Make tunFd == 0 a hard error with exported ErrInvalidTunnelFD (remove dead fallback code)
- Apply defaults in ConfigFromJSON to prevent partially-initialized configs
- Add nil guards for listener/urlOpener interfaces in public SDK entry points
- Reorder config save before OnLoginSuccess to prevent teardown race
- Add explanatory comment for urlOpener.Open goroutine

* Make urlOpener.Open() synchronous in device auth flow
2025-12-30 16:41:36 +00:00
Misha Bragin
9ed1437442 Add DEX IdP Support (#4949) 2025-12-30 07:42:34 -05:00
Pascal Fischer
a8604ef51c [management] filter own peer when having a group to peer policy to themself (#4956) 2025-12-30 10:49:43 +01:00
Nicolas Henneaux
d88e046d00 fix(router): nft tables limit number of peers source (#4852)
* fix(router): nft tables limit number of peers source batching them, failing at 3277 prefixes on nftables v1.0.9 with Ubuntu 24.04.3 LTS,  6.14.0-35-generic #35~24.04.1-Ubuntu

* fix(router): nft tables limit number of prefixes on ipSet creation
2025-12-30 10:48:17 +01:00
Pascal Fischer
1d2c7776fd [management] apply login filter only for setup key peers (#4943) 2025-12-30 10:46:00 +01:00
Haruki Hasegawa
4035f07248 [client] Fix Advanced Settings not opening on Windows with Japanese locale (#4455) (#4637)
The Fyne framework does not support TTC font files.
Use the default system font (Segoe UI) instead, so Windows can
automatically fall back to a Japanese font when needed.
2025-12-30 10:36:12 +01:00
Zoltan Papp
ef2721f4e1 Filter out own peer from remote peers list during peer updates. (#4986) 2025-12-30 10:29:45 +01:00
Louis Li
e11970e32e [client] add reset for management backoff (#4935)
Reset client management grpc client backoff after successful connected to management API.

Current Situation:
If the connection duration exceeds MaxElapsedTime, when the connection is interrupted, the backoff fails immediately due to timeout and does not actually perform a retry.
2025-12-30 08:37:49 +01:00
Maycon Santos
38f9d5ed58 [infra] Preset signal port on templates (#5004)
When passing certificates to signal, it will select port 443 when no port is supplied. This changes forces port 80.
2025-12-29 18:07:06 +03:00
Pascal Fischer
b6a327e0c9 [management] fix scanning authorized user on policy rule (#5002) 2025-12-29 15:03:16 +01:00
Zoltan Papp
67f7b2404e [client, management] Feature/ssh fine grained access (#4969)
Add fine-grained SSH access control with authorized users/groups
2025-12-29 12:50:41 +01:00
Zoltan Papp
73201c4f3e Add conditional checks for FreeBSD diff file generation in release workflow (#5001) 2025-12-29 12:47:38 +01:00
Carlos Hernandez
33d1761fe8 Apply DNS host config on change only (#4695)
Adds a per-instance uint64 hash to DefaultServer to detect identical merged host DNS configs (including extra domains). applyHostConfig computes and compares the hash, skips applying if unchanged, treats hash errors as a fail-safe (proceed to apply), and updates the stored hash only after successful hashing and apply.
2025-12-29 12:43:57 +01:00
August
aa914a0f26 [docs] Fix broken image link (#4876) 2025-12-24 22:06:35 +05:00
Maycon Santos
ab6a9e85de [misc] Use new sign pipelines 0.1.0 (#4993) 2025-12-24 22:03:14 +05:00
Maycon Santos
d3b123c76d [ci] Add FreeBSD port release job to GitHub Actions (#4916)
adds a job that produces new freebsd release files
2025-12-24 11:22:33 +01:00
Viktor Liu
fc4932a23f [client] Fix Linux UI flickering on state updates (#4886) 2025-12-24 11:06:13 +01:00
66 changed files with 4842 additions and 701 deletions

View File

@@ -39,7 +39,7 @@ jobs:
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 8m -failfast -v -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.23"
SIGN_PIPE_VER: "v0.1.0"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -19,6 +19,100 @@ concurrency:
cancel-in-progress: true
jobs:
release_freebsd_port:
name: "FreeBSD Port / Build & Test"
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff
run: |
if ls netbird-*.diff 1> /dev/null 2>&1; then
echo "diff_exists=true" >> $GITHUB_OUTPUT
else
echo "diff_exists=false" >> $GITHUB_OUTPUT
echo "No diff file generated (port may already be up to date)"
fi
- name: Extract version
if: steps.check_diff.outputs.diff_exists == 'true'
id: version
run: |
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
prepare: |
# Install required packages
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
# Clone ports tree (shallow, only what we need)
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
cd /usr/ports
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin
# Find the diff file
echo "Finding diff file..."
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
echo "Found: $DIFF_FILE"
if [[ -z "$DIFF_FILE" ]]; then
echo "ERROR: Could not find diff file"
find ~ -name "*.diff" -type f 2>/dev/null || true
exit 1
fi
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
cd /usr/ports
patch -p1 -V none < "$DIFF_FILE"
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@v4
with:
name: freebsd-port-files
path: |
./netbird-*-issue.txt
./netbird-*.diff
retention-days: 30
release:
runs-on: ubuntu-latest-m
env:

View File

@@ -113,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.

View File

@@ -386,6 +386,97 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
const octet2Count = 25
const octet3Count = 255
prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1))
for i := 1; i < octet2Count; i++ {
for j := 1; j < octet3Count; j++ {
addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0})
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
}
}
_, err = manager.AddRouteFiltering(
nil,
prefixes,
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{},
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper()
require.Equal(t, len(got), len(want), "expression count mismatch")

View File

@@ -48,9 +48,11 @@ const (
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
const refreshRulesMapError = "refresh rules map: %w"
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500
refreshRulesMapError = "refresh rules map: %w"
)
var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
@@ -513,16 +515,35 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
}
elements := convertPrefixesToSet(prefixes)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
nElements := len(elements)
maxElements := maxPrefixesSet * 2
initialElements := elements[:min(maxElements, nElements)]
if err := r.conn.AddSet(nfset, initialElements); err != nil {
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
var subEnd int
for subStart := maxElements; subStart < nElements; subStart += maxElements {
subEnd = min(subStart+maxElements, nElements)
subElement := elements[subStart:subEnd]
nSubPrefixes := len(subElement) / 2
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
}
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
return nfset, nil
}

View File

@@ -4,6 +4,7 @@
package device
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
@@ -45,10 +46,31 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
}
}
// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0).
// This typically means the Swift code couldn't find the utun control socket.
var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)")
func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface")
dupTunFd, err := unix.Dup(t.tunFd)
var tunDevice tun.Device
var err error
// Validate the tunnel file descriptor.
// On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider.
// A value of 0 means the Swift code couldn't find the utun control socket
// (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in
// tvOS SDK headers). This is a hard error - there's no viable fallback
// since tun.CreateTUN() cannot work within the iOS/tvOS sandbox.
if t.tunFd == 0 {
log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " +
"On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.")
return nil, ErrInvalidTunnelFD
}
// Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider
var dupTunFd int
dupTunFd, err = unix.Dup(t.tunFd)
if err != nil {
log.Errorf("Unable to dup tun fd: %v", err)
return nil, err
@@ -60,7 +82,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
_ = unix.Close(dupTunFd)
return nil, err
}
tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
if err != nil {
log.Errorf("Unable to create new tun device from fd: %v", err)
_ = unix.Close(dupTunFd)

View File

@@ -80,6 +80,7 @@ type DefaultServer struct {
updateSerial uint64
previousConfigHash uint64
currentConfig HostDNSConfig
currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
@@ -207,6 +208,7 @@ func newDefaultServer(
hostsDNSHolder: newHostsDNSHolder(),
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
}
// register with root zone, handler chain takes care of the routing
@@ -586,8 +588,29 @@ func (s *DefaultServer) applyHostConfig() {
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
if err != nil {
log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err)
// Fall through to apply config anyway (fail-safe approach)
} else if s.currentConfigHash == hash {
log.Debugf("not applying host config as there are no changes")
return
}
log.Debugf("applying host config as there are changes")
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err)
return
}
// Only update hash if it was computed successfully and config was applied
if err == nil {
s.currentConfigHash = hash
}
s.registerFallback(config)

View File

@@ -1602,7 +1602,10 @@ func TestExtraDomains(t *testing.T) {
"other.example.com.",
"duplicate.example.com.",
},
applyHostConfigCall: 4,
// Expect 3 calls instead of 4 because when deregistering duplicate.example.com,
// the domain remains in the config (ref count goes from 2 to 1), so the host
// config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 3,
},
{
name: "Config update with new domains after registration",
@@ -1657,7 +1660,10 @@ func TestExtraDomains(t *testing.T) {
expectedMatchOnly: []string{
"extra.example.com.",
},
applyHostConfigCall: 3,
// Expect 2 calls instead of 3 because when deregistering protected.example.com,
// it's removed from extraDomains but still remains in the config (from customZones),
// so the host config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 2,
},
{
name: "Register domain that is part of nameserver group",

View File

@@ -1121,6 +1121,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateOfflinePeers(networkMap.GetOfflinePeers())
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
@@ -1129,32 +1138,34 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return err
}
} else {
err := e.removePeers(networkMap.GetRemotePeers())
err := e.removePeers(remotePeers)
if err != nil {
return err
}
err = e.modifyPeers(networkMap.GetRemotePeers())
err = e.modifyPeers(remotePeers)
if err != nil {
return err
}
err = e.addNewPeers(networkMap.GetRemotePeers())
err = e.addNewPeers(remotePeers)
if err != nil {
return err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
e.networkSerial = serial

View File

@@ -11,15 +11,18 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetStatus() (bool, []sshserver.SessionInfo)
UpdateSSHAuth(config *sshauth.Config)
}
func (e *Engine) setupSSHPortRedirection() error {
@@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
return sshServer.GetStatus()
}
// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
if sshAuth == nil {
return
}
if e.sshServer == nil {
return
}
protoUsers := sshAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range sshAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
// Update SSH server with new authorization configuration
authConfig := &sshauth.Config{
UserIDClaim: sshAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
}
e.sshServer.UpdateSSHAuth(authConfig)
}

View File

@@ -148,13 +148,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open(engineCtx context.Context) error {
conn.semaphore.Add(engineCtx)
if err := conn.semaphore.Add(engineCtx); err != nil {
return err
}
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.opened {
conn.semaphore.Done(engineCtx)
conn.semaphore.Done()
return nil
}
@@ -165,6 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
conn.semaphore.Done()
return err
}
conn.workerICE = workerICE
@@ -200,7 +203,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done(conn.ctx)
conn.semaphore.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()

View File

@@ -3,6 +3,7 @@ package profilemanager
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/url"
"os"
@@ -820,3 +821,85 @@ func readConfig(configPath string, createIfMissing bool) (*Config, error) {
func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(context.Background(), path, config)
}
// DirectWriteOutConfig writes config directly without atomic temp file operations.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectWriteOutConfig(path string, config *Config) error {
return util.DirectWriteJson(context.Background(), path, config)
}
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) {
log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input)
if err != nil {
return nil, err
}
err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg)
return cfg, err
}
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
// Enforce permissions on existing config files (same as UpdateOrCreateConfig)
if err := util.EnforcePermission(input.ConfigPath); err != nil {
log.Errorf("failed to enforce permission on config file: %v", err)
}
return directUpdate(input)
}
func directUpdate(input ConfigInput) (*Config, error) {
config := &Config{}
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
return nil, err
}
updated, err := config.apply(input)
if err != nil {
return nil, err
}
if updated {
if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err
}
}
return config, nil
}
// ConfigToJSON serializes a Config struct to a JSON string.
// This is useful for exporting config to alternative storage mechanisms
// (e.g., UserDefaults on tvOS where file writes are blocked).
func ConfigToJSON(config *Config) (string, error) {
bs, err := json.MarshalIndent(config, "", " ")
if err != nil {
return "", err
}
return string(bs), nil
}
// ConfigFromJSON deserializes a JSON string to a Config struct.
// This is useful for restoring config from alternative storage mechanisms.
// After unmarshaling, defaults are applied to ensure the config is fully initialized.
func ConfigFromJSON(jsonStr string) (*Config, error) {
config := &Config{}
err := json.Unmarshal([]byte(jsonStr), config)
if err != nil {
return nil, err
}
// Apply defaults to ensure required fields are initialized.
// This mirrors what readConfig does after loading from file.
if _, err := config.apply(ConfigInput{}); err != nil {
return nil, fmt.Errorf("failed to apply defaults to config: %w", err)
}
return config, nil
}

View File

@@ -22,7 +22,7 @@ const (
defaultTempDir = "/var/lib/netbird/tmp-install"
pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
)
var (

View File

@@ -22,8 +22,8 @@ const (
msiLogFile = "msi.log"
msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
msiDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
exeDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
)
var (

View File

@@ -75,6 +75,8 @@ type Client struct {
dnsManager dns.IosDnsManager
loginComplete bool
connectClient *internal.ConnectClient
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
preloadedConfig *profilemanager.Config
}
// NewClient instantiate a new Client
@@ -92,17 +94,44 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
}
}
// SetConfigFromJSON loads config from a JSON string into memory.
// This is used on tvOS where file writes to App Group containers are blocked.
// When set, IsLoginRequired() and Run() will use this preloaded config instead of reading from file.
func (c *Client) SetConfigFromJSON(jsonStr string) error {
cfg, err := profilemanager.ConfigFromJSON(jsonStr)
if err != nil {
log.Errorf("SetConfigFromJSON: failed to parse config JSON: %v", err)
return err
}
c.preloadedConfig = cfg
log.Infof("SetConfigFromJSON: config loaded successfully from JSON")
return nil
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
exportEnvList(envList)
log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
StateFilePath: c.stateFile,
})
if err != nil {
return err
var cfg *profilemanager.Config
var err error
// Use preloaded config if available (tvOS where file writes are blocked)
if c.preloadedConfig != nil {
log.Infof("Run: using preloaded config from memory")
cfg = c.preloadedConfig
} else {
log.Infof("Run: loading config from file")
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
StateFilePath: c.stateFile,
})
if err != nil {
return err
}
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
@@ -120,7 +149,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.ctxCancelLock.Unlock()
auth := NewAuthWithConfig(ctx, cfg)
err = auth.Login()
err = auth.LoginSync()
if err != nil {
return err
}
@@ -208,14 +237,45 @@ func (c *Client) IsLoginRequired() bool {
defer c.ctxCancelLock.Unlock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
var cfg *profilemanager.Config
var err error
needsLogin, _ := internal.IsLoginRequired(ctx, cfg)
// Use preloaded config if available (tvOS where file writes are blocked)
if c.preloadedConfig != nil {
log.Infof("IsLoginRequired: using preloaded config from memory")
cfg = c.preloadedConfig
} else {
log.Infof("IsLoginRequired: loading config from file")
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
log.Errorf("IsLoginRequired: failed to load config: %v", err)
// If we can't load config, assume login is required
return true
}
}
if cfg == nil {
log.Errorf("IsLoginRequired: config is nil")
return true
}
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
if err != nil {
log.Errorf("IsLoginRequired: check failed: %v", err)
// If the check fails, assume login is required to be safe
return true
}
log.Infof("IsLoginRequired: needsLogin=%v", needsLogin)
return needsLogin
}
// loginForMobileAuthTimeout is the timeout for requesting auth info from the server
const loginForMobileAuthTimeout = 30 * time.Second
func (c *Client) LoginForMobile() string {
var ctx context.Context
//nolint
@@ -228,16 +288,26 @@ func (c *Client) LoginForMobile() string {
defer c.ctxCancelLock.Unlock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
cfg, err := profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
log.Errorf("LoginForMobile: failed to load config: %v", err)
return fmt.Sprintf("failed to load config: %v", err)
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
if err != nil {
return err.Error()
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
// Use a bounded timeout for the auth info request to prevent indefinite hangs
authInfoCtx, authInfoCancel := context.WithTimeout(ctx, loginForMobileAuthTimeout)
defer authInfoCancel()
flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx)
if err != nil {
return err.Error()
}
@@ -249,10 +319,14 @@ func (c *Client) LoginForMobile() string {
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
return
}
jwtToken := tokenInfo.GetTokenToUse()
_ = internal.Login(ctx, cfg, "", jwtToken)
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
log.Errorf("LoginForMobile: Login failed: %v", err)
return
}
c.loginComplete = true
}()

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
)
@@ -33,7 +34,8 @@ type ErrListener interface {
// URLOpener it is a callback interface. The Open function will be triggered if
// the backend want to show an url for the user
type URLOpener interface {
Open(string)
Open(url string, userCode string)
OnLoginSuccess()
}
// Auth can register or login new client
@@ -72,13 +74,32 @@ func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
// is not supported and returns false without saving the configuration. For other errors return false.
func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
if listener == nil {
log.Errorf("SaveConfigIfSSOSupported: listener is nil")
return
}
go func() {
sso, err := a.saveConfigIfSSOSupported()
if err != nil {
listener.OnError(err)
} else {
listener.OnSuccess(sso)
}
}()
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
@@ -97,12 +118,29 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
return true, err
}
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) {
if resultListener == nil {
log.Errorf("LoginWithSetupKeyAndSaveConfig: resultListener is nil")
return
}
go func() {
err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName)
if err != nil {
resultListener.OnError(err)
} else {
resultListener.OnSuccess()
}
}()
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
@@ -118,10 +156,14 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
return fmt.Errorf("backoff cycle failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
return profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
}
func (a *Auth) Login() error {
// LoginSync performs a synchronous login check without UI interaction
// Used for background VPN connection where user should already be authenticated
func (a *Auth) LoginSync() error {
var needsLogin bool
// check if we need to generate JWT token
@@ -135,23 +177,142 @@ func (a *Auth) Login() error {
jwtToken := ""
if needsLogin {
return fmt.Errorf("Not authenticated")
return fmt.Errorf("not authenticated")
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return backoff.Permanent(err)
}
return err
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
return nil
}
// Login performs interactive login with device authentication support
// Deprecated: Use LoginWithDeviceName instead to ensure proper device naming on tvOS
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool) {
// Use empty device name - system will use hostname as fallback
a.LoginWithDeviceName(resultListener, urlOpener, forceDeviceAuth, "")
}
// LoginWithDeviceName performs interactive login with device authentication support
// The deviceName parameter allows specifying a custom device name (required for tvOS)
func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool, deviceName string) {
if resultListener == nil {
log.Errorf("LoginWithDeviceName: resultListener is nil")
return
}
if urlOpener == nil {
log.Errorf("LoginWithDeviceName: urlOpener is nil")
resultListener.OnError(fmt.Errorf("urlOpener is nil"))
return
}
go func() {
err := a.login(urlOpener, forceDeviceAuth, deviceName)
if err != nil {
resultListener.OnError(err)
} else {
resultListener.OnSuccess()
}
}()
}
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
var needsLogin bool
// Create context with device name if provided
ctx := a.ctx
if deviceName != "" {
//nolint:staticcheck
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
}
// check if we need to generate JWT token
err := a.withBackOff(ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err = a.withBackOff(ctx, func() error {
err := internal.Login(ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return backoff.Permanent(err)
}
return err
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
// Save the config before notifying success to ensure persistence completes
// before the callback potentially triggers teardown on the Swift side.
// Note: This differs from Android which doesn't save config after login.
// On iOS/tvOS, we save here because:
// 1. The config may have been modified during login (e.g., new tokens)
// 2. On tvOS, the Network Extension context may be the only place with
// write permissions to the App Group container
if a.cfgPath != "" {
if err := profilemanager.DirectWriteOutConfig(a.cfgPath, a.config); err != nil {
log.Warnf("failed to save config after login: %v", err)
}
}
// Notify caller of successful login synchronously before returning
urlOpener.OnLoginSuccess()
return nil
}
const authInfoRequestTimeout = 30 * time.Second
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
if err != nil {
return nil, err
}
// Use a bounded timeout for the auth info request to prevent indefinite hangs
authInfoCtx, authInfoCancel := context.WithTimeout(a.ctx, authInfoRequestTimeout)
defer authInfoCancel()
flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx)
if err != nil {
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
}
urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
@@ -160,3 +321,24 @@ func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}
// GetConfigJSON returns the current config as a JSON string.
// This can be used by the caller to persist the config via alternative storage
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
func (a *Auth) GetConfigJSON() (string, error) {
if a.config == nil {
return "", fmt.Errorf("no config available")
}
return profilemanager.ConfigToJSON(a.config)
}
// SetConfigFromJSON loads config from a JSON string.
// This can be used to restore config from alternative storage mechanisms.
func (a *Auth) SetConfigFromJSON(jsonStr string) error {
cfg, err := profilemanager.ConfigFromJSON(jsonStr)
if err != nil {
return err
}
a.config = cfg
return nil
}

View File

@@ -112,6 +112,8 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
// Commit write out the changes into config file
func (p *Preferences) Commit() error {
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
_, err := profilemanager.DirectUpdateOrCreateConfig(p.configInput)
return err
}

184
client/ssh/auth/auth.go Normal file
View File

@@ -0,0 +1,184 @@
package auth
import (
"errors"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
const (
// DefaultUserIDClaim is the default JWT claim used to extract user IDs
DefaultUserIDClaim = "sub"
// Wildcard is a special user ID that matches all users
Wildcard = "*"
)
var (
ErrEmptyUserID = errors.New("JWT user ID is empty")
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
)
// Authorizer handles SSH fine-grained access control authorization
type Authorizer struct {
// UserIDClaim is the JWT claim to extract the user ID from
userIDClaim string
// authorizedUsers is a list of hashed user IDs authorized to access this peer
authorizedUsers []sshuserhash.UserIDHash
// machineUsers maps OS login usernames to lists of authorized user indexes
machineUsers map[string][]uint32
// mu protects the list of users
mu sync.RWMutex
}
// Config contains configuration for the SSH authorizer
type Config struct {
// UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
UserIDClaim string
// AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
AuthorizedUsers []sshuserhash.UserIDHash
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
// If a user wants to login as a specific OS user, their index must be in the corresponding list
MachineUsers map[string][]uint32
}
// NewAuthorizer creates a new SSH authorizer with empty configuration
func NewAuthorizer() *Authorizer {
a := &Authorizer{
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
}
return a
}
// Update updates the authorizer configuration with new values
func (a *Authorizer) Update(config *Config) {
a.mu.Lock()
defer a.mu.Unlock()
if config == nil {
// Clear authorization
a.userIDClaim = DefaultUserIDClaim
a.authorizedUsers = []sshuserhash.UserIDHash{}
a.machineUsers = make(map[string][]uint32)
log.Info("SSH authorization cleared")
return
}
userIDClaim := config.UserIDClaim
if userIDClaim == "" {
userIDClaim = DefaultUserIDClaim
}
a.userIDClaim = userIDClaim
// Store authorized users list
a.authorizedUsers = config.AuthorizedUsers
// Store machine users mapping
machineUsers := make(map[string][]uint32)
for osUser, indexes := range config.MachineUsers {
if len(indexes) > 0 {
machineUsers[osUser] = indexes
}
}
a.machineUsers = machineUsers
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
len(config.AuthorizedUsers), len(machineUsers))
}
// Authorize validates if a user is authorized to login as the specified OS user
// Returns nil if authorized, or an error describing why authorization failed
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
if jwtUserID == "" {
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
return ErrEmptyUserID
}
// Hash the JWT user ID for comparison
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
if err != nil {
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
return fmt.Errorf("failed to hash user ID: %w", err)
}
a.mu.RLock()
defer a.mu.RUnlock()
// Find the index of this user in the authorized list
userIndex, found := a.findUserIndex(hashedUserID)
if !found {
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
return ErrUserNotAuthorized
}
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
}
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
// Checks wildcard mapping first, then specific OS user mappings
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
return nil
}
}
// Check for specific OS username mapping
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
if !hasMachineUserMapping {
// No mapping for this OS user - deny by default (fail closed)
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
return ErrNoMachineUserMapping
}
// Check if user's index is in the allowed indexes for this specific OS user
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
return ErrUserNotMappedToOSUser
}
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
return nil
}
// GetUserIDClaim returns the JWT claim name used to extract user IDs
func (a *Authorizer) GetUserIDClaim() string {
a.mu.RLock()
defer a.mu.RUnlock()
return a.userIDClaim
}
// findUserIndex finds the index of a hashed user ID in the authorized users list
// Returns the index and true if found, 0 and false if not found
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
for i, id := range a.authorizedUsers {
if id == hashedUserID {
return i, true
}
}
return 0, false
}
// isIndexInList checks if an index exists in a list of indexes
func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
for _, idx := range indexes {
if idx == index {
return true
}
}
return false
}

View File

@@ -0,0 +1,612 @@
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/sshauth"
)
func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list with one user
authorizedUserHash, err := sshauth.HashUserID("authorized-user")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Try to authorize a different user
err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
}
authorizer.Update(config)
// All attempts should fail when no machine user mappings exist (fail closed)
err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1}, // user1 and user2 can access root
"postgres": {1, 2}, // user2 and user3 can access postgres
"admin": {0}, // only user1 can access admin
},
}
authorizer.Update(config)
// user1 (index 0) should access root and admin
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user1", "admin")
assert.NoError(t, err)
// user2 (index 1) should access root and postgres
err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user3 (index 2) should access postgres
err = authorizer.Authorize("user3", "postgres")
assert.NoError(t, err)
}
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1}, // user1 and user2 can access root
"postgres": {1, 2}, // user2 and user3 can access postgres
"admin": {0}, // only user1 can access admin
},
}
authorizer.Update(config)
// user1 (index 0) should NOT access postgres
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 (index 1) should NOT access admin
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access root
err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access admin
err = authorizer.Authorize("user3", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
}
func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0}, // only root is mapped
},
}
authorizer.Update(config)
// user1 should NOT access an unmapped OS user (fail closed)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Empty user ID should fail
err = authorizer.Authorize("", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrEmptyUserID)
}
func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
authorizer := NewAuthorizer()
// Set up multiple authorized users
userHashes := make([]sshauth.UserIDHash, 10)
for i := 0; i < 10; i++ {
hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
require.NoError(t, err)
userHashes[i] = hash
}
// Create machine user mapping for all users
rootIndexes := make([]uint32, 10)
for i := 0; i < 10; i++ {
rootIndexes[i] = uint32(i)
}
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: userHashes,
MachineUsers: map[string][]uint32{
"root": rootIndexes,
},
}
authorizer.Update(config)
// All users should be authorized for root
for i := 0; i < 10; i++ {
err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
assert.NoError(t, err, "user%d should be authorized", i)
}
// User not in list should fail
err := authorizer.Authorize("unknown-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
authorizer := NewAuthorizer()
// Set up initial configuration
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{"root": {0}},
}
authorizer.Update(config)
// user1 should be authorized
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Clear configuration
authorizer.Update(nil)
// user1 should no longer be authorized
err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
// Machine users with empty index lists should be filtered out
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0},
"postgres": {}, // empty list - should be filtered out
"admin": nil, // nil list - should be filtered out
},
}
authorizer.Update(config)
// root should work
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// postgres should fail (no mapping)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// admin should fail (no mapping)
err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
authorizer := NewAuthorizer()
// Set up with custom user ID claim
user1Hash, err := sshauth.HashUserID("user@example.com")
require.NoError(t, err)
config := &Config{
UserIDClaim: "email",
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0},
},
}
authorizer.Update(config)
// Verify the custom claim is set
assert.Equal(t, "email", authorizer.GetUserIDClaim())
// Authorize with email as user ID
err = authorizer.Authorize("user@example.com", "root")
assert.NoError(t, err)
}
func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
authorizer := NewAuthorizer()
// Verify default claim
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
assert.Equal(t, "sub", authorizer.GetUserIDClaim())
// Set up with empty user ID claim (should use default)
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: "", // empty - should use default
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Should fall back to default
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
}
func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
authorizer := NewAuthorizer()
// Create a large authorized users list
const numUsers = 1000
userHashes := make([]sshauth.UserIDHash, numUsers)
for i := 0; i < numUsers; i++ {
hash, err := sshauth.HashUserID("user" + string(rune(i)))
require.NoError(t, err)
userHashes[i] = hash
}
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: userHashes,
MachineUsers: map[string][]uint32{
"root": {0, 500, 999}, // first, middle, and last user
},
}
authorizer.Update(config)
// First user should have access
err := authorizer.Authorize("user"+string(rune(0)), "root")
assert.NoError(t, err)
// Middle user should have access
err = authorizer.Authorize("user"+string(rune(500)), "root")
assert.NoError(t, err)
// Last user should have access
err = authorizer.Authorize("user"+string(rune(999)), "root")
assert.NoError(t, err)
// User not in mapping should NOT have access
err = authorizer.Authorize("user"+string(rune(100)), "root")
assert.Error(t, err)
}
func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1},
},
}
authorizer.Update(config)
// Test concurrent authorization calls (should be safe to read concurrently)
const numGoroutines = 100
errChan := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(idx int) {
user := "user1"
if idx%2 == 0 {
user = "user2"
}
err := authorizer.Authorize(user, "root")
errChan <- err
}(i)
}
// Wait for all goroutines to complete and collect errors
for i := 0; i < numGoroutines; i++ {
err := <-errChan
assert.NoError(t, err)
}
}
func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
// Configure with wildcard - all authorized users can access any OS user
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"*": {0, 1, 2}, // wildcard with all user indexes
},
}
authorizer.Update(config)
// All authorized users should be able to access any OS user
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
err = authorizer.Authorize("user3", "admin")
assert.NoError(t, err)
err = authorizer.Authorize("user1", "ubuntu")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "nginx")
assert.NoError(t, err)
err = authorizer.Authorize("user3", "docker")
assert.NoError(t, err)
}
func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
// Configure with wildcard
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"*": {0},
},
}
authorizer.Update(config)
// user1 should have access
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Unauthorized user should still be denied even with wildcard
err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure with both wildcard and specific mappings
// Wildcard takes precedence for users in the wildcard index list
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"*": {0, 1}, // wildcard for both users
"root": {0}, // specific mapping that would normally restrict to user1 only
},
}
authorizer.Update(config)
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
// Both users should be able to access any other OS user via wildcard
err = authorizer.Authorize("user1", "postgres")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "admin")
assert.NoError(t, err)
}
func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure WITHOUT wildcard - only specific mappings
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"root": {0}, // only user1
"postgres": {1}, // only user2
},
}
authorizer.Update(config)
// user1 can access root
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// user2 can access postgres
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user1 cannot access postgres
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 cannot access root
err = authorizer.Authorize("user2", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// Neither can access unmapped OS users
err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
// This test covers the scenario where wildcard exists with limited indexes.
// Only users whose indexes are in the wildcard list can access any OS user via wildcard.
// Other users can only access OS users they are explicitly mapped to.
authorizer := NewAuthorizer()
// Create two authorized user hashes (simulating the base64-encoded hashes in the config)
wasmHash, err := sshauth.HashUserID("wasm")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure with wildcard having only index 0, and specific mappings for other OS users
config := &Config{
UserIDClaim: "sub",
AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
MachineUsers: map[string][]uint32{
"*": {0}, // wildcard with only index 0 - only wasm has wildcard access
"alice": {1}, // specific mapping for user2
"bob": {1}, // specific mapping for user2
},
}
authorizer.Update(config)
// wasm (index 0) should access any OS user via wildcard
err = authorizer.Authorize("wasm", "root")
assert.NoError(t, err, "wasm should access root via wildcard")
err = authorizer.Authorize("wasm", "alice")
assert.NoError(t, err, "wasm should access alice via wildcard")
err = authorizer.Authorize("wasm", "bob")
assert.NoError(t, err, "wasm should access bob via wildcard")
err = authorizer.Authorize("wasm", "postgres")
assert.NoError(t, err, "wasm should access postgres via wildcard")
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
err = authorizer.Authorize("user2", "alice")
assert.NoError(t, err, "user2 should access alice via explicit mapping")
err = authorizer.Authorize("user2", "bob")
assert.NoError(t, err, "user2 should access bob via explicit mapping")
err = authorizer.Authorize("user2", "root")
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "postgres")
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// Unauthorized user should still be denied
err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
}

View File

@@ -27,9 +27,11 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestMain(m *testing.M) {
@@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) {
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
@@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) {
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
@@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"sub": user,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}

View File

@@ -23,10 +23,12 @@ import (
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestJWTEnforcement(t *testing.T) {
@@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) {
tc.setupServer(server)
}
// Always set up authorization for test-user to ensure tests fail at JWT validation stage
testUserHash, err := sshuserhash.HashUserID("test-user")
require.NoError(t, err)
// Get current OS username for machine user mapping
currentUser := testutil.GetTestUsername(t)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
currentUser: {0}, // Allow test-user (index 0) to access current OS user
},
}
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())

View File

@@ -21,6 +21,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
@@ -138,6 +139,8 @@ type Server struct {
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
authorizer *sshauth.Authorizer
suSupportsPty bool
loginIsUtilLinux bool
}
@@ -179,6 +182,7 @@ func New(config *Config) *Server {
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
}
return s
@@ -320,6 +324,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.wgAddress = addr
}
// UpdateSSHAuth updates the SSH fine-grained access control configuration
// This should be called when network map updates include new SSH auth configuration
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
s.mu.Lock()
defer s.mu.Unlock()
// Reset JWT validator/extractor to pick up new userIDClaim
s.jwtValidator = nil
s.jwtExtractor = nil
s.authorizer.Update(config)
}
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
@@ -328,6 +345,7 @@ func (s *Server) ensureJWTValidator() error {
return nil
}
config := s.jwtConfig
authorizer := s.authorizer
s.mu.RUnlock()
if config == nil {
@@ -343,9 +361,16 @@ func (s *Server) ensureJWTValidator() error {
true,
)
extractor := jwt.NewClaimsExtractor(
// Use custom userIDClaim from authorizer if available
extractorOptions := []jwt.ClaimsExtractorOption{
jwt.WithAudience(config.Audience),
)
}
if authorizer.GetUserIDClaim() != "" {
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
}
extractor := jwt.NewClaimsExtractor(extractorOptions...)
s.mu.Lock()
defer s.mu.Unlock()
@@ -493,29 +518,41 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
osUsername := ctx.User()
remoteAddr := ctx.RemoteAddr()
if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
return false
}
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
s.mu.RLock()
authorizer := s.authorizer
s.mu.RUnlock()
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
return false
}
key := newAuthKey(osUsername, remoteAddr)
s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId
s.mu.Unlock()
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
return true
}

View File

@@ -312,6 +312,8 @@ type serviceClient struct {
daemonVersion string
updateIndicationLock sync.Mutex
isUpdateIconActive bool
settingsEnabled bool
profilesEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
@@ -907,7 +909,7 @@ func (s *serviceClient) updateStatus() error {
var systrayIconState bool
switch {
case status.Status == string(internal.StatusConnected):
case status.Status == string(internal.StatusConnected) && !s.connected:
s.connected = true
s.sendNotification = true
if s.isUpdateIconActive {
@@ -921,6 +923,7 @@ func (s *serviceClient) updateStatus() error {
s.mUp.Disable()
s.mDown.Enable()
s.mNetworks.Enable()
s.mExitNode.Enable()
go s.updateExitNodes()
systrayIconState = true
case status.Status == string(internal.StatusConnecting):
@@ -1274,19 +1277,22 @@ func (s *serviceClient) checkAndUpdateFeatures() {
return
}
s.updateIndicationLock.Lock()
defer s.updateIndicationLock.Unlock()
// Update settings menu based on current features
if features != nil && features.DisableUpdateSettings {
s.setSettingsEnabled(false)
} else {
s.setSettingsEnabled(true)
settingsEnabled := features == nil || !features.DisableUpdateSettings
if s.settingsEnabled != settingsEnabled {
s.settingsEnabled = settingsEnabled
s.setSettingsEnabled(settingsEnabled)
}
// Update profile menu based on current features
if s.mProfile != nil {
if features != nil && features.DisableProfiles {
s.mProfile.setEnabled(false)
} else {
s.mProfile.setEnabled(true)
profilesEnabled := features == nil || !features.DisableProfiles
if s.profilesEnabled != profilesEnabled {
s.profilesEnabled = profilesEnabled
s.mProfile.setEnabled(profilesEnabled)
}
}
}

View File

@@ -31,7 +31,6 @@ func (s *serviceClient) getWindowsFontFilePath() string {
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Segoeui.ttf",
"zh-TW": "Segoeui.ttf",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",

17
go.mod
View File

@@ -22,7 +22,7 @@ require (
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.73.0
google.golang.org/grpc v1.75.0
google.golang.org/protobuf v1.36.8
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -41,6 +41,7 @@ require (
github.com/coder/websocket v1.8.13
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
github.com/dexidp/dex/api/v2 v2.4.0
github.com/eko/gocache/lib/v4 v4.2.0
github.com/eko/gocache/store/go_cache/v4 v4.2.2
github.com/eko/gocache/store/redis/v4 v4.2.2
@@ -97,10 +98,10 @@ require (
github.com/yusufpapurcu/wmi v1.2.4
github.com/zcalusic/sysinfo v1.1.3
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0
go.opentelemetry.io/otel v1.35.0
go.opentelemetry.io/otel v1.37.0
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.35.0
go.opentelemetry.io/otel/sdk/metric v1.35.0
go.opentelemetry.io/otel/metric v1.37.0
go.opentelemetry.io/otel/sdk/metric v1.37.0
go.uber.org/mock v0.5.0
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
@@ -124,7 +125,7 @@ require (
require (
cloud.google.com/go/auth v0.3.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.6.0 // indirect
cloud.google.com/go/compute/metadata v0.7.0 // indirect
dario.cat/mergo v1.0.0 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
@@ -170,7 +171,7 @@ require (
github.com/fyne-io/oksvg v0.2.0 // indirect
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
@@ -248,8 +249,8 @@ require (
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.33.0 // indirect
golang.org/x/text v0.31.0 // indirect

40
go.sum
View File

@@ -4,8 +4,8 @@ cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU=
cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
@@ -117,6 +117,8 @@ github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70J
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A=
github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -164,8 +166,8 @@ github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVin
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
@@ -561,22 +563,22 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.4
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
@@ -761,6 +763,8 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk=
google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
@@ -770,8 +774,8 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM=
google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8=
google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 h1:FiusG7LWj+4byqhbvmB+Q93B/mOxJLN2DTozDuZm4EU=
google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:kXqgZtrWaf6qS3jZOCnCH7WYfrvFjkC51bM8fz3RsCA=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
@@ -779,8 +783,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=

View File

@@ -53,7 +53,8 @@ services:
command: [
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
"--log-file", "console"
"--log-file", "console",
"--port", "80"
]
# Relay

View File

@@ -0,0 +1,554 @@
#!/bin/bash
set -e
# NetBird Getting Started with Dex IDP
# This script sets up NetBird with Dex as the identity provider
# Sed pattern to strip base64 padding characters
SED_STRIP_PADDING='s/=//g'
check_docker_compose() {
if command -v docker-compose &> /dev/null
then
echo "docker-compose"
return
fi
if docker compose --help &> /dev/null
then
echo "docker compose"
return
fi
echo "docker-compose is not installed or not in PATH. Please follow the steps from the official guide: https://docs.docker.com/engine/install/" > /dev/stderr
exit 1
}
check_jq() {
if ! command -v jq &> /dev/null
then
echo "jq is not installed or not in PATH, please install with your package manager. e.g. sudo apt install jq" > /dev/stderr
exit 1
fi
return 0
}
get_main_ip_address() {
if [[ "$OSTYPE" == "darwin"* ]]; then
interface=$(route -n get default | grep 'interface:' | awk '{print $2}')
ip_address=$(ifconfig "$interface" | grep 'inet ' | awk '{print $2}')
else
interface=$(ip route | grep default | awk '{print $5}' | head -n 1)
ip_address=$(ip addr show "$interface" | grep 'inet ' | awk '{print $2}' | cut -d'/' -f1)
fi
echo "$ip_address"
return 0
}
check_nb_domain() {
DOMAIN=$1
if [[ "$DOMAIN-x" == "-x" ]]; then
echo "The NETBIRD_DOMAIN variable cannot be empty." > /dev/stderr
return 1
fi
if [[ "$DOMAIN" == "netbird.example.com" ]]; then
echo "The NETBIRD_DOMAIN cannot be netbird.example.com" > /dev/stderr
return 1
fi
return 0
}
read_nb_domain() {
READ_NETBIRD_DOMAIN=""
echo -n "Enter the domain you want to use for NetBird (e.g. netbird.my-domain.com): " > /dev/stderr
read -r READ_NETBIRD_DOMAIN < /dev/tty
if ! check_nb_domain "$READ_NETBIRD_DOMAIN"; then
read_nb_domain
fi
echo "$READ_NETBIRD_DOMAIN"
return 0
}
get_turn_external_ip() {
TURN_EXTERNAL_IP_CONFIG="#external-ip="
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip')
if [[ "x-$IP" != "x-" ]]; then
TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
fi
echo "$TURN_EXTERNAL_IP_CONFIG"
return 0
}
wait_dex() {
set +e
echo -n "Waiting for Dex to become ready (via $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN)"
counter=1
while true; do
# Check Dex through Caddy proxy (also validates TLS is working)
if curl -sk -f -o /dev/null "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/.well-known/openid-configuration" 2>/dev/null; then
break
fi
if [[ $counter -eq 60 ]]; then
echo ""
echo "Taking too long. Checking logs..."
$DOCKER_COMPOSE_COMMAND logs --tail=20 caddy
$DOCKER_COMPOSE_COMMAND logs --tail=20 dex
fi
echo -n " ."
sleep 2
counter=$((counter + 1))
done
echo " done"
set -e
return 0
}
init_environment() {
CADDY_SECURE_DOMAIN=""
NETBIRD_PORT=80
NETBIRD_HTTP_PROTOCOL="http"
NETBIRD_RELAY_PROTO="rel"
TURN_USER="self"
TURN_PASSWORD=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
TURN_MIN_PORT=49152
TURN_MAX_PORT=65535
TURN_EXTERNAL_IP_CONFIG=$(get_turn_external_ip)
# Generate secrets for Dex
DEX_DASHBOARD_CLIENT_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
# Generate admin password
NETBIRD_ADMIN_PASSWORD=$(openssl rand -base64 16 | sed "$SED_STRIP_PADDING")
if ! check_nb_domain "$NETBIRD_DOMAIN"; then
NETBIRD_DOMAIN=$(read_nb_domain)
fi
if [[ "$NETBIRD_DOMAIN" == "use-ip" ]]; then
NETBIRD_DOMAIN=$(get_main_ip_address)
else
NETBIRD_PORT=443
CADDY_SECURE_DOMAIN=", $NETBIRD_DOMAIN:$NETBIRD_PORT"
NETBIRD_HTTP_PROTOCOL="https"
NETBIRD_RELAY_PROTO="rels"
fi
check_jq
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
if [[ -f dex.yaml ]]; then
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
echo "You can use the following commands:"
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
echo " rm -f docker-compose.yml Caddyfile dex.yaml dashboard.env turnserver.conf management.json relay.env"
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
exit 1
fi
echo Rendering initial files...
render_docker_compose > docker-compose.yml
render_caddyfile > Caddyfile
render_dex_config > dex.yaml
render_dashboard_env > dashboard.env
render_management_json > management.json
render_turn_server_conf > turnserver.conf
render_relay_env > relay.env
echo -e "\nStarting Dex IDP\n"
$DOCKER_COMPOSE_COMMAND up -d caddy dex
# Wait for Dex to be ready (through caddy proxy)
sleep 3
wait_dex
echo -e "\nStarting NetBird services\n"
$DOCKER_COMPOSE_COMMAND up -d
echo -e "\nDone!\n"
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
echo ""
echo "Login with the following credentials:"
echo "Email: admin@$NETBIRD_DOMAIN" | tee .env
echo "Password: $NETBIRD_ADMIN_PASSWORD" | tee -a .env
echo ""
echo "Dex admin UI is not available (Dex has no built-in UI)."
echo "To add more users, edit dex.yaml and restart: $DOCKER_COMPOSE_COMMAND restart dex"
return 0
}
render_caddyfile() {
cat <<EOF
{
debug
servers :80,:443 {
protocols h1 h2c h2 h3
}
}
(security_headers) {
header * {
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
X-Content-Type-Options "nosniff"
X-Frame-Options "SAMEORIGIN"
X-XSS-Protection "1; mode=block"
-Server
Referrer-Policy strict-origin-when-cross-origin
}
}
:80${CADDY_SECURE_DOMAIN} {
import security_headers
# Relay
reverse_proxy /relay* relay:80
# Signal
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
# Management
reverse_proxy /api/* management:80
reverse_proxy /management.ManagementService/* h2c://management:80
# Dex
reverse_proxy /dex/* dex:5556
# Dashboard
reverse_proxy /* dashboard:80
}
EOF
return 0
}
render_dex_config() {
# Generate bcrypt hash of the admin password
# Using a simple approach - htpasswd or python if available
ADMIN_PASSWORD_HASH=""
if command -v htpasswd &> /dev/null; then
ADMIN_PASSWORD_HASH=$(htpasswd -bnBC 10 "" "$NETBIRD_ADMIN_PASSWORD" | tr -d ':\n')
elif command -v python3 &> /dev/null; then
ADMIN_PASSWORD_HASH=$(python3 -c "import bcrypt; print(bcrypt.hashpw('$NETBIRD_ADMIN_PASSWORD'.encode(), bcrypt.gensalt(rounds=10)).decode())" 2>/dev/null || echo "")
fi
# Fallback to a known hash if we can't generate one
if [[ -z "$ADMIN_PASSWORD_HASH" ]]; then
# This is hash of "password" - user should change it
ADMIN_PASSWORD_HASH='$2a$10$2b2cU8CPhOTaGrs1HRQuAueS7JTT5ZHsHSzYiFPm1leZck7Mc8T4W'
NETBIRD_ADMIN_PASSWORD="password"
echo "Warning: Could not generate password hash. Using default password: password. Please change it in dex.yaml" > /dev/stderr
fi
cat <<EOF
# Dex configuration for NetBird
# Generated by getting-started-with-dex.sh
issuer: $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex
storage:
type: sqlite3
config:
file: /var/dex/dex.db
web:
http: 0.0.0.0:5556
# gRPC API for user management (used by NetBird IDP manager)
grpc:
addr: 0.0.0.0:5557
oauth2:
skipApprovalScreen: true
# Static OAuth2 clients for NetBird
staticClients:
# Dashboard client
- id: netbird-dashboard
name: NetBird Dashboard
secret: $DEX_DASHBOARD_CLIENT_SECRET
redirectURIs:
- $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/nb-auth
- $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/nb-silent-auth
# CLI client (public - uses PKCE)
- id: netbird-cli
name: NetBird CLI
public: true
redirectURIs:
- http://localhost:53000/
- http://localhost:54000/
# Enable password database for static users
enablePasswordDB: true
# Static users - add more users here as needed
staticPasswords:
- email: "admin@$NETBIRD_DOMAIN"
hash: "$ADMIN_PASSWORD_HASH"
username: "admin"
userID: "$(uuidgen 2>/dev/null || cat /proc/sys/kernel/random/uuid 2>/dev/null || echo "admin-user-id-001")"
# Optional: Add external identity provider connectors
# connectors:
# - type: github
# id: github
# name: GitHub
# config:
# clientID: \$GITHUB_CLIENT_ID
# clientSecret: \$GITHUB_CLIENT_SECRET
# redirectURI: $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/callback
#
# - type: ldap
# id: ldap
# name: LDAP
# config:
# host: ldap.example.com:636
# insecureNoSSL: false
# bindDN: cn=admin,dc=example,dc=com
# bindPW: admin
# userSearch:
# baseDN: ou=users,dc=example,dc=com
# filter: "(objectClass=person)"
# username: uid
# idAttr: uid
# emailAttr: mail
# nameAttr: cn
EOF
return 0
}
render_turn_server_conf() {
cat <<EOF
listening-port=3478
$TURN_EXTERNAL_IP_CONFIG
tls-listening-port=5349
min-port=$TURN_MIN_PORT
max-port=$TURN_MAX_PORT
fingerprint
lt-cred-mech
user=$TURN_USER:$TURN_PASSWORD
realm=wiretrustee.com
cert=/etc/coturn/certs/cert.pem
pkey=/etc/coturn/private/privkey.pem
log-file=stdout
no-software-attribute
pidfile="/var/tmp/turnserver.pid"
no-cli
EOF
return 0
}
render_management_json() {
cat <<EOF
{
"Stuns": [
{
"Proto": "udp",
"URI": "stun:$NETBIRD_DOMAIN:3478"
}
],
"Relay": {
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
"CredentialsTTL": "24h",
"Secret": "$NETBIRD_RELAY_AUTH_SECRET"
},
"Signal": {
"Proto": "$NETBIRD_HTTP_PROTOCOL",
"URI": "$NETBIRD_DOMAIN:$NETBIRD_PORT"
},
"HttpConfig": {
"AuthIssuer": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex",
"AuthAudience": "netbird-dashboard",
"OIDCConfigEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/.well-known/openid-configuration"
},
"IdpManagerConfig": {
"ManagerType": "dex",
"ClientConfig": {
"Issuer": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex"
},
"ExtraConfig": {
"GRPCAddr": "dex:5557"
}
},
"DeviceAuthorizationFlow": {
"Provider": "hosted",
"ProviderConfig": {
"Audience": "netbird-cli",
"ClientID": "netbird-cli",
"Scope": "openid profile email offline_access",
"DeviceAuthEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/device/code",
"TokenEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/token"
}
},
"PKCEAuthorizationFlow": {
"ProviderConfig": {
"Audience": "netbird-cli",
"ClientID": "netbird-cli",
"Scope": "openid profile email offline_access",
"RedirectURLs": ["http://localhost:53000/", "http://localhost:54000/"]
}
}
}
EOF
return 0
}
render_dashboard_env() {
cat <<EOF
# Endpoints
NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN
NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN
# OIDC
AUTH_AUDIENCE=netbird-dashboard
AUTH_CLIENT_ID=netbird-dashboard
AUTH_CLIENT_SECRET=$DEX_DASHBOARD_CLIENT_SECRET
AUTH_AUTHORITY=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex
USE_AUTH0=false
AUTH_SUPPORTED_SCOPES=openid profile email offline_access
AUTH_REDIRECT_URI=/nb-auth
AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
# SSL
NGINX_SSL_PORT=443
# Letsencrypt
LETSENCRYPT_DOMAIN=none
EOF
return 0
}
render_relay_env() {
cat <<EOF
NB_LOG_LEVEL=info
NB_LISTEN_ADDRESS=:80
NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT
NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
EOF
return 0
}
render_docker_compose() {
cat <<EOF
services:
# Caddy reverse proxy
caddy:
image: caddy
container_name: netbird-caddy
restart: unless-stopped
networks: [netbird]
ports:
- '443:443'
- '443:443/udp'
- '80:80'
volumes:
- netbird_caddy_data:/data
- ./Caddyfile:/etc/caddy/Caddyfile
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Dex - identity provider
dex:
image: ghcr.io/dexidp/dex:v2.38.0
container_name: netbird-dex
restart: unless-stopped
networks: [netbird]
volumes:
- ./dex.yaml:/etc/dex/config.docker.yaml:ro
- netbird_dex_data:/var/dex
command: ["dex", "serve", "/etc/dex/config.docker.yaml"]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# UI dashboard
dashboard:
image: netbirdio/dashboard:latest
container_name: netbird-dashboard
restart: unless-stopped
networks: [netbird]
env_file:
- ./dashboard.env
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Signal
signal:
image: netbirdio/signal:latest
container_name: netbird-signal
restart: unless-stopped
networks: [netbird]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Relay
relay:
image: netbirdio/relay:latest
container_name: netbird-relay
restart: unless-stopped
networks: [netbird]
env_file:
- ./relay.env
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Management
management:
image: netbirdio/management:latest
container_name: netbird-management
restart: unless-stopped
networks: [netbird]
volumes:
- netbird_management:/var/lib/netbird
- ./management.json:/etc/netbird/management.json
command: [
"--port", "80",
"--log-file", "console",
"--log-level", "info",
"--disable-anonymous-metrics=false",
"--single-account-mode-domain=netbird.selfhosted",
"--dns-domain=netbird.selfhosted",
"--idp-sign-key-refresh-enabled",
]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Coturn, AKA TURN server
coturn:
image: coturn/coturn
container_name: netbird-coturn
restart: unless-stopped
volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro
network_mode: host
command:
- -c /etc/turnserver.conf
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
netbird_caddy_data:
netbird_dex_data:
netbird_management:
networks:
netbird:
EOF
return 0
}
init_environment

View File

@@ -178,6 +178,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
@@ -224,7 +225,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -320,6 +321,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
postureChecks, err := c.getPeerPostureChecks(account, peerId)
if err != nil {
@@ -338,7 +340,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -445,7 +447,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -811,7 +813,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -158,5 +158,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
}
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}

View File

@@ -6,7 +6,10 @@ import (
"net/url"
"strings"
log "github.com/sirupsen/logrus"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
@@ -16,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/sshauth"
)
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
@@ -84,15 +88,15 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
return nbConfig
}
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig {
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName)
sshConfig := &proto.SSHConfig{
SshEnabled: peer.SSHEnabled,
SshEnabled: peer.SSHEnabled || enableSSH,
}
if peer.SSHEnabled {
if sshConfig.SshEnabled {
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
}
@@ -110,12 +114,12 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
},
Checks: toProtocolChecks(ctx, checks),
}
@@ -151,9 +155,45 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.ForwardingRules = forwardingRules
}
if networkMap.AuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
}
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
}
return response
}
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
for machineUser, users := range authorizedUsers {
indexes := make([]uint32, 0, len(users))
for userID := range users {
idx, exists := userIDToIndex[userID]
if !exists {
hash, err := sshauth.HashUserID(userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
continue
}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
}
return hashedUsers, machineUsers
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{

View File

@@ -184,8 +184,14 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
realIP := getRealIP(ctx)
sRealIP := realIP.String()
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
userID, err := s.accountManager.GetUserIDByPeerKey(ctx, peerKey.String())
if err != nil {
s.syncSem.Add(-1)
return mapError(ctx, err)
}
metahashed := metaHash(peerMeta, sRealIP)
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
}
@@ -270,6 +276,8 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
unlock()
unlock = nil
log.WithContext(ctx).Debugf("Sync took %s", time.Since(reqStart))
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
@@ -559,6 +567,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
}
log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart))
}()
if loginReq.GetMeta() == nil {
@@ -635,7 +644,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow),
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
Checks: toProtocolChecks(ctx, postureChecks),
}

View File

@@ -1456,21 +1456,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
}
if settings.GroupsPropagationEnabled {
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
if err != nil {
return err
}
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
if err != nil {
return err
}
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
if err != nil {
return err
}
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
if err != nil {
return err
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
}
return nil
@@ -2158,3 +2156,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
return nil
}
func (am *DefaultAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) {
return am.Store.GetUserIDByPeerKey(ctx, store.LockingStrengthNone, peerKey)
}

View File

@@ -123,4 +123,5 @@ type Manager interface {
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error)
}

View File

@@ -397,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}

View File

@@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthUpdate, accountID, groupID)
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
@@ -442,6 +442,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
deletedGroups = append(deletedGroups, group)
}
if len(groupIDsToDelete) == 0 {
return allErrors
}
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
return err
}

View File

@@ -299,7 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
@@ -369,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
PortRanges: []types.RulePortRange{portRange},
}},
}
if protocol == types.PolicyRuleProtocolNetbirdSSH {
policy.Rules[0].AuthorizedUser = userAuth.UserId
}
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
if err != nil {
@@ -449,6 +452,18 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
LocalFlags: &api.PeerLocalFlags{
BlockInbound: &peer.Meta.Flags.BlockInbound,
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
DisableDns: &peer.Meta.Flags.DisableDNS,
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
},
}
if !approved {
@@ -463,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
if osVersion == "" {
osVersion = peer.Meta.Core
}
return &api.PeerBatch{
CreatedAt: peer.CreatedAt,
Id: peer.ID,
@@ -492,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
LocalFlags: &api.PeerLocalFlags{
BlockInbound: &peer.Meta.Flags.BlockInbound,
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
DisableDns: &peer.Meta.Flags.DisableDNS,
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
},
}
}

View File

@@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
pr.Protocol = types.PolicyRuleProtocolUDP
case api.PolicyRuleUpdateProtocolIcmp:
pr.Protocol = types.PolicyRuleProtocolICMP
case api.PolicyRuleUpdateProtocolNetbirdSsh:
pr.Protocol = types.PolicyRuleProtocolNetbirdSSH
default:
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
return
@@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
}
}
if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 {
for _, sourceGroupID := range pr.Sources {
_, ok := (*rule.AuthorizedGroups)[sourceGroupID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w)
return
}
}
pr.AuthorizedGroups = *rule.AuthorizedGroups
}
// validate policy object
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
@@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
DestinationResource: r.DestinationResource.ToAPIResponse(),
}
if len(r.AuthorizedGroups) != 0 {
authorizedGroupsCopy := r.AuthorizedGroups
rule.AuthorizedGroups = &authorizedGroupsCopy
}
if len(r.Ports) != 0 {
portsCopy := r.Ports
rule.Ports = &portsCopy

View File

@@ -0,0 +1,445 @@
package idp
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/dexidp/dex/api/v2"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// DexManager implements the Manager interface for Dex IDP.
// It uses Dex's gRPC API to manage users in the password database.
type DexManager struct {
grpcAddr string
httpClient ManagerHTTPClient
helper ManagerHelper
appMetrics telemetry.AppMetrics
mux sync.Mutex
conn *grpc.ClientConn
}
// DexClientConfig Dex manager client configuration.
type DexClientConfig struct {
// GRPCAddr is the address of Dex's gRPC API (e.g., "localhost:5557")
GRPCAddr string
// Issuer is the Dex issuer URL (e.g., "https://dex.example.com/dex")
Issuer string
}
// NewDexManager creates a new instance of DexManager.
func NewDexManager(config DexClientConfig, appMetrics telemetry.AppMetrics) (*DexManager, error) {
if config.GRPCAddr == "" {
return nil, fmt.Errorf("dex IdP configuration is incomplete, GRPCAddr is missing")
}
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
return &DexManager{
grpcAddr: config.GRPCAddr,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// getConnection returns a gRPC connection to Dex, creating one if necessary.
// It also checks if an existing connection is still healthy and reconnects if needed.
func (dm *DexManager) getConnection(ctx context.Context) (*grpc.ClientConn, error) {
dm.mux.Lock()
defer dm.mux.Unlock()
if dm.conn != nil {
state := dm.conn.GetState()
// If connection is shutdown or in a transient failure, close and reconnect
if state == connectivity.Shutdown || state == connectivity.TransientFailure {
log.WithContext(ctx).Debugf("Dex gRPC connection in state %s, reconnecting", state)
_ = dm.conn.Close()
dm.conn = nil
} else {
return dm.conn, nil
}
}
log.WithContext(ctx).Debugf("connecting to Dex gRPC API at %s", dm.grpcAddr)
conn, err := grpc.NewClient(dm.grpcAddr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return nil, fmt.Errorf("failed to connect to Dex gRPC API: %w", err)
}
dm.conn = conn
return conn, nil
}
// getDexClient returns a Dex API client.
func (dm *DexManager) getDexClient(ctx context.Context) (api.DexClient, error) {
conn, err := dm.getConnection(ctx)
if err != nil {
return nil, err
}
return api.NewDexClient(conn), nil
}
// encodeDexUserID encodes a user ID and connector ID into Dex's composite format.
// This is the reverse of parseDexUserID - it creates the base64-encoded protobuf
// format that Dex uses in JWT tokens.
func encodeDexUserID(userID, connectorID string) string {
// Build simple protobuf structure:
// Field 1 (tag 0x0a): user ID string
// Field 2 (tag 0x12): connector ID string
buf := make([]byte, 0, 2+len(userID)+2+len(connectorID))
// Field 1: user ID
buf = append(buf, 0x0a) // tag for field 1, wire type 2 (length-delimited)
buf = append(buf, byte(len(userID))) // length
buf = append(buf, []byte(userID)...) // value
// Field 2: connector ID
buf = append(buf, 0x12) // tag for field 2, wire type 2 (length-delimited)
buf = append(buf, byte(len(connectorID))) // length
buf = append(buf, []byte(connectorID)...) // value
return base64.StdEncoding.EncodeToString(buf)
}
// parseDexUserID extracts the actual user ID from Dex's composite user ID.
// Dex encodes user IDs in JWT tokens as base64-encoded protobuf with format:
// - Field 1 (string): actual user ID
// - Field 2 (string): connector ID (e.g., "local")
// If the ID is not in this format, it returns the original ID.
func parseDexUserID(compositeID string) string {
// Try to decode as standard base64
decoded, err := base64.StdEncoding.DecodeString(compositeID)
if err != nil {
// Try URL-safe base64
decoded, err = base64.RawURLEncoding.DecodeString(compositeID)
if err != nil {
// Not base64 encoded, return as-is
return compositeID
}
}
// Parse the simple protobuf structure
// Field 1 (tag 0x0a): user ID string
// Field 2 (tag 0x12): connector ID string
if len(decoded) < 2 {
return compositeID
}
// Check for field 1 tag (0x0a = field 1, wire type 2/length-delimited)
if decoded[0] != 0x0a {
return compositeID
}
// Read the length of the user ID string
length := int(decoded[1])
if len(decoded) < 2+length {
return compositeID
}
// Extract the user ID
userID := string(decoded[2 : 2+length])
return userID
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
// Dex doesn't support app metadata, so this is a no-op.
func (dm *DexManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from Dex via user ID.
func (dm *DexManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountGetUserDataByID()
}
client, err := dm.getDexClient(ctx)
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to list passwords from Dex: %w", err)
}
// Try to parse the composite user ID from Dex JWT token
actualUserID := parseDexUserID(userID)
for _, p := range resp.Passwords {
// Match against both the raw userID and the parsed actualUserID
if p.UserId == userID || p.UserId == actualUserID {
return &UserData{
Email: p.Email,
Name: p.Username,
ID: userID, // Return the original ID for consistency
}, nil
}
}
return nil, fmt.Errorf("user with ID %s not found", userID)
}
// GetAccount returns all the users for a given account.
// Since Dex doesn't have account concepts, this returns all users.
func (dm *DexManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountGetAccount()
}
users, err := dm.getAllUsers(ctx)
if err != nil {
return nil, err
}
// Set the account ID for all users
for _, user := range users {
user.AppMetadata.WTAccountID = accountID
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// Since Dex doesn't have account concepts, all users are returned under UnsetAccountID.
func (dm *DexManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountGetAllAccounts()
}
users, err := dm.getAllUsers(ctx)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = users
return indexedUsers, nil
}
// CreateUser creates a new user in Dex's password database.
func (dm *DexManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountCreateUser()
}
client, err := dm.getDexClient(ctx)
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
// Generate a random password for the new user
password := GeneratePassword(16, 2, 2, 2)
// Hash the password using bcrypt
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
}
// Generate a user ID from email (Dex uses email as the key, but we need a stable ID)
userID := strings.ReplaceAll(email, "@", "-at-")
userID = strings.ReplaceAll(userID, ".", "-")
req := &api.CreatePasswordReq{
Password: &api.Password{
Email: email,
Username: name,
UserId: userID,
Hash: hashedPassword,
},
}
resp, err := client.CreatePassword(ctx, req)
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to create user in Dex: %w", err)
}
if resp.AlreadyExists {
return nil, fmt.Errorf("user with email %s already exists", email)
}
log.WithContext(ctx).Debugf("created user %s in Dex", email)
return &UserData{
Email: email,
Name: name,
ID: userID,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTInvitedBy: invitedByEmail,
},
}, nil
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (dm *DexManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountGetUserByEmail()
}
client, err := dm.getDexClient(ctx)
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to list passwords from Dex: %w", err)
}
users := make([]*UserData, 0)
for _, p := range resp.Passwords {
if strings.EqualFold(p.Email, email) {
// Encode the user ID in Dex's composite format to match stored IDs
encodedID := encodeDexUserID(p.UserId, "local")
users = append(users, &UserData{
Email: p.Email,
Name: p.Username,
ID: encodedID,
})
}
}
return users, nil
}
// InviteUserByID resends an invitation to a user.
// Dex doesn't support invitations, so this returns an error.
func (dm *DexManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented for Dex")
}
// DeleteUser deletes a user from Dex by user ID.
func (dm *DexManager) DeleteUser(ctx context.Context, userID string) error {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountDeleteUser()
}
client, err := dm.getDexClient(ctx)
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
// First, find the user's email by ID
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return fmt.Errorf("failed to list passwords from Dex: %w", err)
}
// Try to parse the composite user ID from Dex JWT token
actualUserID := parseDexUserID(userID)
var email string
for _, p := range resp.Passwords {
if p.UserId == userID || p.UserId == actualUserID {
email = p.Email
break
}
}
if email == "" {
return fmt.Errorf("user with ID %s not found", userID)
}
// Delete the user by email
deleteResp, err := client.DeletePassword(ctx, &api.DeletePasswordReq{
Email: email,
})
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return fmt.Errorf("failed to delete user from Dex: %w", err)
}
if deleteResp.NotFound {
return fmt.Errorf("user with email %s not found", email)
}
log.WithContext(ctx).Debugf("deleted user %s from Dex", email)
return nil
}
// getAllUsers retrieves all users from Dex's password database.
func (dm *DexManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
client, err := dm.getDexClient(ctx)
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
if err != nil {
if dm.appMetrics != nil {
dm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to list passwords from Dex: %w", err)
}
users := make([]*UserData, 0, len(resp.Passwords))
for _, p := range resp.Passwords {
// Encode the user ID in Dex's composite format (base64-encoded protobuf)
// to match how NetBird stores user IDs from Dex JWT tokens.
// The connector ID "local" is used for Dex's password database.
encodedID := encodeDexUserID(p.UserId, "local")
users = append(users, &UserData{
Email: p.Email,
Name: p.Username,
ID: encodedID,
})
}
return users, nil
}

View File

@@ -0,0 +1,137 @@
package idp
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewDexManager(t *testing.T) {
type test struct {
name string
inputConfig DexClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := DexClientConfig{
GRPCAddr: "localhost:5557",
Issuer: "https://dex.example.com/dex",
}
testCase1 := test{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
}
testCase2Config := defaultTestConfig
testCase2Config.GRPCAddr = ""
testCase2 := test{
name: "Missing GRPCAddr Configuration",
inputConfig: testCase2Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when GRPCAddr is empty",
}
// Test with empty issuer - should still work since issuer is optional for the manager
testCase3Config := defaultTestConfig
testCase3Config.Issuer = ""
testCase3 := test{
name: "Missing Issuer Configuration - OK",
inputConfig: testCase3Config,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error when issuer is empty",
}
for _, testCase := range []test{testCase1, testCase2, testCase3} {
t.Run(testCase.name, func(t *testing.T) {
manager, err := NewDexManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
if err == nil {
require.NotNil(t, manager, "manager should not be nil")
require.Equal(t, testCase.inputConfig.GRPCAddr, manager.grpcAddr, "grpcAddr should match")
}
})
}
}
func TestDexManagerUpdateUserAppMetadata(t *testing.T) {
config := DexClientConfig{
GRPCAddr: "localhost:5557",
Issuer: "https://dex.example.com/dex",
}
manager, err := NewDexManager(config, &telemetry.MockAppMetrics{})
require.NoError(t, err, "should create manager without error")
// UpdateUserAppMetadata should be a no-op for Dex
err = manager.UpdateUserAppMetadata(context.Background(), "test-user-id", AppMetadata{
WTAccountID: "test-account",
})
require.NoError(t, err, "UpdateUserAppMetadata should not return error")
}
func TestDexManagerInviteUserByID(t *testing.T) {
config := DexClientConfig{
GRPCAddr: "localhost:5557",
Issuer: "https://dex.example.com/dex",
}
manager, err := NewDexManager(config, &telemetry.MockAppMetrics{})
require.NoError(t, err, "should create manager without error")
// InviteUserByID should return an error for Dex
err = manager.InviteUserByID(context.Background(), "test-user-id")
require.Error(t, err, "InviteUserByID should return error")
require.Contains(t, err.Error(), "not implemented", "error should mention not implemented")
}
func TestParseDexUserID(t *testing.T) {
tests := []struct {
name string
compositeID string
expectedID string
}{
{
name: "Parse base64-encoded protobuf composite ID",
// This is a real Dex composite ID: contains user ID "cf5db180-d360-484d-9b78-c5db92146420" and connector "local"
compositeID: "CiRjZjVkYjE4MC1kMzYwLTQ4NGQtOWI3OC1jNWRiOTIxNDY0MjASBWxvY2Fs",
expectedID: "cf5db180-d360-484d-9b78-c5db92146420",
},
{
name: "Return plain ID unchanged",
compositeID: "simple-user-id",
expectedID: "simple-user-id",
},
{
name: "Return UUID unchanged",
compositeID: "cf5db180-d360-484d-9b78-c5db92146420",
expectedID: "cf5db180-d360-484d-9b78-c5db92146420",
},
{
name: "Handle empty string",
compositeID: "",
expectedID: "",
},
{
name: "Handle invalid base64",
compositeID: "not-valid-base64!!!",
expectedID: "not-valid-base64!!!",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseDexUserID(tt.compositeID)
require.Equal(t, tt.expectedID, result, "parsed user ID should match expected")
})
}
}

View File

@@ -173,40 +173,40 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
return NewZitadelManager(*zitadelClientConfig, appMetrics)
case "authentik":
authentikConfig := AuthentikClientConfig{
return NewAuthentikManager(AuthentikClientConfig{
Issuer: config.ClientConfig.Issuer,
ClientID: config.ClientConfig.ClientID,
TokenEndpoint: config.ClientConfig.TokenEndpoint,
GrantType: config.ClientConfig.GrantType,
Username: config.ExtraConfig["Username"],
Password: config.ExtraConfig["Password"],
}
return NewAuthentikManager(authentikConfig, appMetrics)
}, appMetrics)
case "okta":
oktaClientConfig := OktaClientConfig{
return NewOktaManager(OktaClientConfig{
Issuer: config.ClientConfig.Issuer,
TokenEndpoint: config.ClientConfig.TokenEndpoint,
GrantType: config.ClientConfig.GrantType,
APIToken: config.ExtraConfig["ApiToken"],
}
return NewOktaManager(oktaClientConfig, appMetrics)
}, appMetrics)
case "google":
googleClientConfig := GoogleWorkspaceClientConfig{
return NewGoogleWorkspaceManager(ctx, GoogleWorkspaceClientConfig{
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
CustomerID: config.ExtraConfig["CustomerId"],
}
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
}, appMetrics)
case "jumpcloud":
jumpcloudConfig := JumpCloudClientConfig{
return NewJumpCloudManager(JumpCloudClientConfig{
APIToken: config.ExtraConfig["ApiToken"],
}
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
}, appMetrics)
case "pocketid":
pocketidConfig := PocketIdClientConfig{
return NewPocketIdManager(PocketIdClientConfig{
APIToken: config.ExtraConfig["ApiToken"],
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
}
return NewPocketIdManager(pocketidConfig, appMetrics)
}, appMetrics)
case "dex":
return NewDexManager(DexClientConfig{
GRPCAddr: config.ExtraConfig["GRPCAddr"],
Issuer: config.ClientConfig.Issuer,
}, appMetrics)
default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
}

View File

@@ -2,11 +2,12 @@ package mock_server
import (
"context"
"github.com/netbirdio/netbird/shared/auth"
"net"
"net/netip"
"time"
"github.com/netbirdio/netbird/shared/auth"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -988,3 +989,7 @@ func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, ac
}
return nil
}
func (am *MockAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) {
return "something", nil
}

View File

@@ -91,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// fetch all the peers that have access to the user's peers
for _, peer := range peers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers())
for _, p := range aclPeers {
peersMap[p.ID] = p
}
@@ -1057,7 +1057,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
}
for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers())
for _, aclPeer := range aclPeers {
if aclPeer.ID == peer.ID {
return peer, nil

View File

@@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
}
})
t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 8)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
@@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
})
t.Run("check port ranges support for older peers", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 1)
assert.Contains(t, peers, account.Peers["peerI"])
@@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}
t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*types.FirewallRule{
@@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerB"])
expectedFirewallRules := []*types.FirewallRule{
@@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*types.FirewallRule{
@@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerB"])
expectedFirewallRules := []*types.FirewallRule{
@@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 7)
expectedFirewallRules := []*types.FirewallRule{
@@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 5)
// assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"])

View File

@@ -63,6 +63,8 @@ type SqlStore struct {
installationPK int
storeEngine types.Engine
pool *pgxpool.Pool
transactionTimeout time.Duration
}
type installation struct {
@@ -84,6 +86,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
conns = runtime.NumCPU()
}
transactionTimeout := 5 * time.Minute
if v := os.Getenv("NB_STORE_TRANSACTION_TIMEOUT"); v != "" {
if parsed, err := time.ParseDuration(v); err == nil {
transactionTimeout = parsed
}
}
log.WithContext(ctx).Infof("Setting transaction timeout to %v", transactionTimeout)
if storeEngine == types.SqliteStoreEngine {
if err == nil {
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
@@ -101,7 +111,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
if skipMigration {
log.WithContext(ctx).Infof("skipping migration")
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1, transactionTimeout: transactionTimeout}, nil
}
if err := migratePreAuto(ctx, db); err != nil {
@@ -120,7 +130,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePostAuto: %w", err)
}
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1, transactionTimeout: transactionTimeout}, nil
}
func GetKeyQueryCondition(s *SqlStore) string {
@@ -1910,16 +1920,17 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
if len(policyIDs) == 0 {
return nil, nil
}
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, policyIDs)
if err != nil {
return nil, err
}
rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
var r types.PolicyRule
var dest, destRes, sources, sourceRes, ports, portRanges []byte
var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte
var enabled, bidirectional sql.NullBool
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges)
var authorizedUser sql.NullString
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &authorizedUser)
if err == nil {
if enabled.Valid {
r.Enabled = enabled.Bool
@@ -1945,6 +1956,12 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
if portRanges != nil {
_ = json.Unmarshal(portRanges, &r.PortRanges)
}
if authorizedGroups != nil {
_ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups)
}
if authorizedUser.Valid {
r.AuthorizedUser = authorizedUser.String
}
}
return &r, err
})
@@ -2890,8 +2907,11 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string)
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout)
defer cancel()
startTime := time.Now()
tx := s.db.Begin()
tx := s.db.WithContext(timeoutCtx).Begin()
if tx.Error != nil {
return tx.Error
}
@@ -2926,6 +2946,9 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
err := operation(repo)
if err != nil {
tx.Rollback()
if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
log.WithContext(ctx).Warnf("transaction exceeded %s timeout after %v, stack: %s", s.transactionTimeout, time.Since(startTime), debug.Stack())
}
return err
}
@@ -2938,13 +2961,19 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
}
err = tx.Commit().Error
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
log.WithContext(ctx).Warnf("transaction commit exceeded %s timeout after %v, stack: %s", s.transactionTimeout, time.Since(startTime), debug.Stack())
}
return err
}
log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
if s.metrics != nil {
s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
}
return err
return nil
}
func (s *SqlStore) withTx(tx *gorm.DB) Store {
@@ -4075,3 +4104,21 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro
return peers, nil
}
func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var userID string
result := tx.Model(&nbpeer.Peer{}).
Select("user_id").
Take(&userID, GetKeyQueryCondition(s), peerKey)
if result.Error != nil {
return "", status.Errorf(status.Internal, "failed to get user ID by peer key")
}
return userID, nil
}

View File

@@ -3718,6 +3718,69 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
}
}
func TestSqlStore_GetUserIDByPeerKey(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
userID := "test-user-123"
peerKey := "peer-key-abc"
peer := &nbpeer.Peer{
ID: "test-peer-1",
Key: peerKey,
AccountID: existingAccountID,
UserID: userID,
IP: net.IP{10, 0, 0, 1},
DNSLabel: "test-peer-1",
}
err = store.AddPeerToAccount(context.Background(), peer)
require.NoError(t, err)
retrievedUserID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, peerKey)
require.NoError(t, err)
assert.Equal(t, userID, retrievedUserID)
}
func TestSqlStore_GetUserIDByPeerKey_NotFound(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
nonExistentPeerKey := "non-existent-peer-key"
userID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, nonExistentPeerKey)
require.Error(t, err)
assert.Equal(t, "", userID)
}
func TestSqlStore_GetUserIDByPeerKey_NoUserID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
peerKey := "peer-key-abc"
peer := &nbpeer.Peer{
ID: "test-peer-1",
Key: peerKey,
AccountID: existingAccountID,
UserID: "",
IP: net.IP{10, 0, 0, 1},
DNSLabel: "test-peer-1",
}
err = store.AddPeerToAccount(context.Background(), peer)
require.NoError(t, err)
retrievedUserID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, peerKey)
require.NoError(t, err)
assert.Equal(t, "", retrievedUserID)
}
func TestSqlStore_ApproveAccountPeers(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
accountID := "test-account"
@@ -3794,3 +3857,30 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) {
})
})
}
func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) {
if os.Getenv("NETBIRD_STORE_ENGINE") == "mysql" {
t.Skip("Skipping timeout test for MySQL")
}
t.Setenv("NB_STORE_TRANSACTION_TIMEOUT", "1s")
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
sqlStore, ok := store.(*SqlStore)
require.True(t, ok)
assert.Equal(t, 1*time.Second, sqlStore.transactionTimeout)
ctx := context.Background()
err = sqlStore.ExecuteInTransaction(ctx, func(transaction Store) error {
// Sleep for 2 seconds to exceed the 1 second timeout
time.Sleep(2 * time.Second)
return nil
})
// The transaction should fail with an error (either timeout or already rolled back)
require.Error(t, err)
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
}

View File

@@ -204,6 +204,7 @@ type Store interface {
MarkAccountPrimary(ctx context.Context, accountID string) error
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error)
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
}
const (

View File

@@ -16,6 +16,7 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -45,8 +46,10 @@ const (
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
nativeSSHPortString = "22022"
nativeSSHPortNumber = 22022
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
defaultSSHPortString = "22"
defaultSSHPortNumber = 22
)
type supportedFeatures struct {
@@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap(
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
@@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap(
}
}
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
@@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap(
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
AuthorizedUsers: authorizedUsers,
EnableSSH: enableSSH,
}
if metrics != nil {
@@ -1009,8 +1015,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs
sshEnabled := false
for _, policy := range a.Policies {
if !policy.Enabled {
@@ -1053,10 +1061,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
sshEnabled = true
switch {
case len(rule.AuthorizedGroups) > 0:
for groupID, localUsers := range rule.AuthorizedGroups {
userIDs, ok := groupIDToUserIDs[groupID]
if !ok {
log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID)
continue
}
if len(localUsers) == 0 {
localUsers = []string{auth.Wildcard}
}
for _, localUser := range localUsers {
if authorizedUsers[localUser] == nil {
authorizedUsers[localUser] = make(map[string]struct{})
}
for _, userID := range userIDs {
authorizedUsers[localUser][userID] = struct{}{}
}
}
}
case rule.AuthorizedUser != "":
if authorizedUsers[auth.Wildcard] == nil {
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
}
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
default:
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
}
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
sshEnabled = true
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
}
}
}
return getAccumulatedResources()
peers, fwRules := getAccumulatedResources()
return peers, fwRules, authorizedUsers, sshEnabled
}
func (a *Account) getAllowedUserIDs() map[string]struct{} {
users := make(map[string]struct{})
for _, nbUser := range a.Users {
if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
users[nbUser.Id] = struct{}{}
}
}
return users
}
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
@@ -1081,12 +1137,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
peersExists[peer.ID] = struct{}{}
}
protocol := rule.Protocol
if protocol == PolicyRuleProtocolNetbirdSSH {
protocol = PolicyRuleProtocolTCP
}
fr := FirewallRule{
PolicyID: rule.ID,
PeerIP: peer.IP.String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(rule.Protocol),
Protocol: string(protocol),
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
@@ -1108,6 +1169,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
}
}
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
}
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
for _, pr := range portRanges {
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
return true
}
}
return false
}
func portsIncludesSSH(ports []string) bool {
for _, port := range ports {
if port == defaultSSHPortString || port == nativeSSHPortString {
return true
}
}
return false
}
// getAllPeersFromGroups for given peer ID and list of groups
//
// Returns a list of peers from specified groups that pass specified posture checks
@@ -1152,7 +1235,11 @@ func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpe
return []*nbpeer.Peer{}, false
}
return []*nbpeer.Peer{peer}, resource.ID == peerID
if peer.ID == peerID {
return []*nbpeer.Peer{}, true
}
return []*nbpeer.Peer{peer}, false
}
// validatePostureChecksOnPeer validates the posture checks on a peer
@@ -1660,6 +1747,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
return nil
}
func (a *Account) GetActiveGroupUsers() map[string][]string {
allGroupID := ""
group, err := a.GetGroupAll()
if err != nil {
log.Errorf("failed to get group all: %v", err)
} else {
allGroupID = group.ID
}
groups := make(map[string][]string, len(a.GroupsG))
for _, user := range a.Users {
if !user.IsBlocked() && !user.IsServiceUser {
for _, groupID := range user.AutoGroups {
groups[groupID] = append(groups[groupID], user.Id)
}
groups[allGroupID] = append(groups[allGroupID], user.Id)
}
}
return groups
}
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
@@ -1691,7 +1798,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer
expanded = append(expanded, &fr)
}
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) {
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
expanded = addNativeSSHRule(base, expanded)
}

View File

@@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
}
}
func Test_GetActiveGroupUsers(t *testing.T) {
tests := []struct {
name string
account *Account
expected map[string][]string
}{
{
name: "all users are active",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1", "group2"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group2", "group3"},
Blocked: false,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group1"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1", "user3"},
"group2": {"user1", "user2"},
"group3": {"user2"},
"": {"user1", "user2", "user3"},
},
},
{
name: "some users are blocked",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1", "group2"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group2", "group3"},
Blocked: true,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group1", "group3"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1", "user3"},
"group2": {"user1"},
"group3": {"user3"},
"": {"user1", "user3"},
},
},
{
name: "all users are blocked",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1"},
Blocked: true,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group2"},
Blocked: true,
},
},
},
expected: map[string][]string{},
},
{
name: "user with no auto groups",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group1"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user2"},
"": {"user1", "user2"},
},
},
{
name: "empty account",
account: &Account{
Users: map[string]*User{},
},
expected: map[string][]string{},
},
{
name: "multiple users in same group",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group1"},
Blocked: false,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group1"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1", "user2", "user3"},
"": {"user1", "user2", "user3"},
},
},
{
name: "user in multiple groups with blocked users",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1", "group2", "group3"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group1", "group2"},
Blocked: true,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group3"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1"},
"group2": {"user1"},
"group3": {"user1", "user3"},
"": {"user1", "user3"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetActiveGroupUsers()
// Check that the number of groups matches
assert.Equal(t, len(tt.expected), len(result), "number of groups should match")
// Check each group's users
for groupID, expectedUsers := range tt.expected {
actualUsers, exists := result[groupID]
assert.True(t, exists, "group %s should exist in result", groupID)
assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID)
}
// Ensure no extra groups in result
for groupID := range result {
_, exists := tt.expected[groupID]
assert.True(t, exists, "unexpected group %s in result", groupID)
}
})
}
}
func Test_FilterZoneRecordsForPeers(t *testing.T) {
tests := []struct {
name string

View File

@@ -38,6 +38,8 @@ type NetworkMap struct {
FirewallRules []*FirewallRule
RoutesFirewallRules []*RouteFirewallRule
ForwardingRules []*ForwardingRule
AuthorizedUsers map[string]map[string]struct{}
EnableSSH bool
}
func (nm *NetworkMap) Merge(other *NetworkMap) {

View File

@@ -69,7 +69,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap)
@@ -141,7 +141,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
b.Run("old builder", func(b *testing.B) {
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -201,7 +201,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap)
@@ -320,7 +320,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -395,7 +395,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap)
@@ -550,7 +550,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -604,7 +604,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap)
@@ -730,7 +730,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap)
@@ -847,7 +847,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})

View File

@@ -23,6 +23,8 @@ const (
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
// PolicyRuleProtocolICMP type of traffic
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
// PolicyRuleProtocolNetbirdSSH type of traffic
PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh")
)
const (
@@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error)
protocol = PolicyRuleProtocolUDP
case "icmp":
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
case "netbird-ssh":
return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil
default:
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
}

View File

@@ -80,6 +80,12 @@ type PolicyRule struct {
// PortRanges a list of port ranges.
PortRanges []RulePortRange `gorm:"serializer:json"`
// AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh
AuthorizedGroups map[string][]string `gorm:"serializer:json"`
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
AuthorizedUser string
}
// Copy returns a copy of a policy rule
@@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule {
Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)),
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
AuthorizedUser: pm.AuthorizedUser,
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
copy(rule.PortRanges, pm.PortRanges)
for k, v := range pm.AuthorizedGroups {
rule.AuthorizedGroups[k] = make([]string, len(v))
copy(rule.AuthorizedGroups[k], v)
}
return rule
}

View File

@@ -523,16 +523,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
_, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
)
if err != nil {
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
}
if userHadPeers {
updateAccountPeers = true
}
updateAccountPeers = true
err = transaction.SaveUser(ctx, updatedUser)
if err != nil {
@@ -581,7 +579,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
}
if settings.GroupsPropagationEnabled && updateAccountPeers {
if updateAccountPeers {
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
return nil, fmt.Errorf("failed to increment network serial: %w", err)
}

View File

@@ -1379,11 +1379,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Creating a new regular user should not update account peers and not send peer update
// Creating a new regular user should send peer update (as users are not filtered yet)
t.Run("creating new regular user with no groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1402,11 +1402,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}
})
// updating user with no linked peers should not update account peers and not send peer update
// updating user with no linked peers should update account peers and send peer update (as users are not filtered yet)
t.Run("updating user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()

View File

@@ -0,0 +1,216 @@
#!/bin/bash
#
# FreeBSD Port Diff Generator for NetBird
#
# This script generates the diff file required for submitting a FreeBSD port update.
# It works on macOS, Linux, and FreeBSD by fetching files from FreeBSD cgit and
# computing checksums from the Go module proxy.
#
# Usage: ./freebsd-port-diff.sh [new_version]
# Example: ./freebsd-port-diff.sh 0.60.7
#
# If no version is provided, it fetches the latest from GitHub.
set -e
GITHUB_REPO="netbirdio/netbird"
PORTS_CGIT_BASE="https://cgit.freebsd.org/ports/plain/security/netbird"
GO_PROXY="https://proxy.golang.org/github.com/netbirdio/netbird/@v"
OUTPUT_DIR="${OUTPUT_DIR:-.}"
AWK_FIRST_FIELD='{print $1}'
fetch_all_tags() {
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
sed 's/.*\/v//' | \
sort -u -V
return 0
}
fetch_current_ports_version() {
echo "Fetching current version from FreeBSD ports..." >&2
curl -sL "${PORTS_CGIT_BASE}/Makefile" 2>/dev/null | \
grep -E "^DISTVERSION=" | \
sed 's/DISTVERSION=[[:space:]]*//' | \
tr -d '\t '
return 0
}
fetch_latest_github_release() {
echo "Fetching latest release from GitHub..." >&2
fetch_all_tags | tail -1
return 0
}
fetch_ports_file() {
local filename="$1"
curl -sL "${PORTS_CGIT_BASE}/${filename}" 2>/dev/null
return 0
}
compute_checksums() {
local version="$1"
local tmpdir
tmpdir=$(mktemp -d)
# shellcheck disable=SC2064
trap "rm -rf '$tmpdir'" EXIT
echo "Downloading files from Go module proxy for v${version}..." >&2
local mod_file="${tmpdir}/v${version}.mod"
local zip_file="${tmpdir}/v${version}.zip"
curl -sL "${GO_PROXY}/v${version}.mod" -o "$mod_file" 2>/dev/null
curl -sL "${GO_PROXY}/v${version}.zip" -o "$zip_file" 2>/dev/null
if [[ ! -s "$mod_file" ]] || [[ ! -s "$zip_file" ]]; then
echo "Error: Could not download files from Go module proxy" >&2
return 1
fi
local mod_sha256 mod_size zip_sha256 zip_size
if command -v sha256sum &>/dev/null; then
mod_sha256=$(sha256sum "$mod_file" | awk "$AWK_FIRST_FIELD")
zip_sha256=$(sha256sum "$zip_file" | awk "$AWK_FIRST_FIELD")
elif command -v shasum &>/dev/null; then
mod_sha256=$(shasum -a 256 "$mod_file" | awk "$AWK_FIRST_FIELD")
zip_sha256=$(shasum -a 256 "$zip_file" | awk "$AWK_FIRST_FIELD")
else
echo "Error: No sha256 command found" >&2
return 1
fi
if [[ "$OSTYPE" == "darwin"* ]]; then
mod_size=$(stat -f%z "$mod_file")
zip_size=$(stat -f%z "$zip_file")
else
mod_size=$(stat -c%s "$mod_file")
zip_size=$(stat -c%s "$zip_file")
fi
echo "TIMESTAMP = $(date +%s)"
echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_sha256}"
echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_size}"
echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_sha256}"
echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_size}"
return 0
}
generate_new_makefile() {
local new_version="$1"
local old_makefile="$2"
# Check if old version had PORTREVISION
if echo "$old_makefile" | grep -q "^PORTREVISION="; then
# Remove PORTREVISION line and update DISTVERSION
echo "$old_makefile" | \
sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/" | \
grep -v "^PORTREVISION="
else
# Just update DISTVERSION
echo "$old_makefile" | \
sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/"
fi
return 0
}
# Parse arguments
NEW_VERSION="${1:-}"
# Auto-detect versions if not provided
OLD_VERSION=$(fetch_current_ports_version)
if [[ -z "$OLD_VERSION" ]]; then
echo "Error: Could not fetch current version from FreeBSD ports" >&2
exit 1
fi
echo "Current FreeBSD ports version: ${OLD_VERSION}" >&2
if [[ -z "$NEW_VERSION" ]]; then
NEW_VERSION=$(fetch_latest_github_release)
if [[ -z "$NEW_VERSION" ]]; then
echo "Error: Could not fetch latest release from GitHub" >&2
exit 1
fi
fi
echo "Target version: ${NEW_VERSION}" >&2
if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then
echo "Port is already at version ${NEW_VERSION}. Nothing to do." >&2
exit 0
fi
echo "" >&2
# Fetch current files
echo "Fetching current Makefile from FreeBSD ports..." >&2
OLD_MAKEFILE=$(fetch_ports_file "Makefile")
if [[ -z "$OLD_MAKEFILE" ]]; then
echo "Error: Could not fetch Makefile" >&2
exit 1
fi
echo "Fetching current distinfo from FreeBSD ports..." >&2
OLD_DISTINFO=$(fetch_ports_file "distinfo")
if [[ -z "$OLD_DISTINFO" ]]; then
echo "Error: Could not fetch distinfo" >&2
exit 1
fi
# Generate new files
echo "Generating new Makefile..." >&2
NEW_MAKEFILE=$(generate_new_makefile "$NEW_VERSION" "$OLD_MAKEFILE")
echo "Computing checksums for new version..." >&2
NEW_DISTINFO=$(compute_checksums "$NEW_VERSION")
if [[ -z "$NEW_DISTINFO" ]]; then
echo "Error: Could not compute checksums" >&2
exit 1
fi
# Create temp files for diff
TMPDIR=$(mktemp -d)
# shellcheck disable=SC2064
trap "rm -rf '$TMPDIR'" EXIT
mkdir -p "${TMPDIR}/a/security/netbird" "${TMPDIR}/b/security/netbird"
echo "$OLD_MAKEFILE" > "${TMPDIR}/a/security/netbird/Makefile"
echo "$OLD_DISTINFO" > "${TMPDIR}/a/security/netbird/distinfo"
echo "$NEW_MAKEFILE" > "${TMPDIR}/b/security/netbird/Makefile"
echo "$NEW_DISTINFO" > "${TMPDIR}/b/security/netbird/distinfo"
# Generate diff
OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}.diff"
echo "" >&2
echo "Generating diff..." >&2
# Generate diff and clean up temp paths to show standard a/b paths
(cd "${TMPDIR}" && diff -ruN "a/security/netbird" "b/security/netbird") > "$OUTPUT_FILE" || true
if [[ ! -s "$OUTPUT_FILE" ]]; then
echo "Error: Generated diff is empty" >&2
exit 1
fi
echo "" >&2
echo "========================================="
echo "Diff saved to: ${OUTPUT_FILE}"
echo "========================================="
echo ""
cat "$OUTPUT_FILE"
echo ""
echo "========================================="
echo ""
echo "Next steps:"
echo "1. Review the diff above"
echo "2. Submit to https://bugs.freebsd.org/bugzilla/"
echo "3. Use ./freebsd-port-issue-body.sh to generate the issue content"
echo ""
echo "For FreeBSD testing (optional but recommended):"
echo " cd /usr/ports/security/netbird"
echo " patch < ${OUTPUT_FILE}"
echo " make stage && make stage-qa && make package && make install"
echo " netbird status"
echo " make deinstall"

View File

@@ -0,0 +1,159 @@
#!/bin/bash
#
# FreeBSD Port Issue Body Generator for NetBird
#
# This script generates the issue body content for submitting a FreeBSD port update
# to the FreeBSD Bugzilla at https://bugs.freebsd.org/bugzilla/
#
# Usage: ./freebsd-port-issue-body.sh [old_version] [new_version]
# Example: ./freebsd-port-issue-body.sh 0.56.0 0.59.1
#
# If no versions are provided, the script will:
# - Fetch OLD version from FreeBSD ports cgit (current version in ports tree)
# - Fetch NEW version from latest NetBird GitHub release tag
set -e
GITHUB_REPO="netbirdio/netbird"
PORTS_CGIT_URL="https://cgit.freebsd.org/ports/plain/security/netbird/Makefile"
fetch_current_ports_version() {
echo "Fetching current version from FreeBSD ports..." >&2
local makefile_content
makefile_content=$(curl -sL "$PORTS_CGIT_URL" 2>/dev/null)
if [[ -z "$makefile_content" ]]; then
echo "Error: Could not fetch Makefile from FreeBSD ports" >&2
return 1
fi
echo "$makefile_content" | grep -E "^DISTVERSION=" | sed 's/DISTVERSION=[[:space:]]*//' | tr -d '\t '
return 0
}
fetch_all_tags() {
# Fetch tags from GitHub tags page (no rate limiting, no auth needed)
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
sed 's/.*\/v//' | \
sort -u -V
return 0
}
fetch_latest_github_release() {
echo "Fetching latest release from GitHub..." >&2
local latest
# Fetch from GitHub tags page
latest=$(fetch_all_tags | tail -1)
if [[ -z "$latest" ]]; then
# Fallback to GitHub API
latest=$(curl -sL "https://api.github.com/repos/${GITHUB_REPO}/releases/latest" 2>/dev/null | \
grep '"tag_name"' | sed 's/.*"tag_name": *"v\([^"]*\)".*/\1/')
fi
if [[ -z "$latest" ]]; then
echo "Error: Could not fetch latest release from GitHub" >&2
return 1
fi
echo "$latest"
return 0
}
OLD_VERSION="${1:-}"
NEW_VERSION="${2:-}"
if [[ -z "$OLD_VERSION" ]]; then
OLD_VERSION=$(fetch_current_ports_version)
if [[ -z "$OLD_VERSION" ]]; then
echo "Error: Could not determine old version. Please provide it manually." >&2
echo "Usage: $0 <old_version> <new_version>" >&2
exit 1
fi
echo "Detected OLD version from FreeBSD ports: $OLD_VERSION" >&2
fi
if [[ -z "$NEW_VERSION" ]]; then
NEW_VERSION=$(fetch_latest_github_release)
if [[ -z "$NEW_VERSION" ]]; then
echo "Error: Could not determine new version. Please provide it manually." >&2
echo "Usage: $0 <old_version> <new_version>" >&2
exit 1
fi
echo "Detected NEW version from GitHub: $NEW_VERSION" >&2
fi
if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then
echo "Warning: OLD and NEW versions are the same ($OLD_VERSION). Port may already be up to date." >&2
fi
echo "" >&2
OUTPUT_DIR="${OUTPUT_DIR:-.}"
fetch_releases_between_versions() {
echo "Fetching release history from GitHub..." >&2
# Fetch all tags and filter to those between OLD and NEW versions
fetch_all_tags | \
while read -r ver; do
if [[ "$(printf '%s\n' "$OLD_VERSION" "$ver" | sort -V | head -n1)" = "$OLD_VERSION" ]] && \
[[ "$(printf '%s\n' "$ver" "$NEW_VERSION" | sort -V | head -n1)" = "$ver" ]] && \
[[ "$ver" != "$OLD_VERSION" ]]; then
echo "$ver"
fi
done
return 0
}
generate_changelog_section() {
local releases
releases=$(fetch_releases_between_versions)
echo "Changelogs:"
if [[ -n "$releases" ]]; then
echo "$releases" | while read -r ver; do
echo "https://github.com/${GITHUB_REPO}/releases/tag/v${ver}"
done
else
echo "https://github.com/${GITHUB_REPO}/releases/tag/v${NEW_VERSION}"
fi
return 0
}
OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}-issue.txt"
cat << EOF > "$OUTPUT_FILE"
BUGZILLA ISSUE DETAILS
======================
Severity: Affects Some People
Summary: security/netbird: Update to ${NEW_VERSION}
Description:
------------
security/netbird: Update ${OLD_VERSION} => ${NEW_VERSION}
$(generate_changelog_section)
Commit log:
https://github.com/${GITHUB_REPO}/compare/v${OLD_VERSION}...v${NEW_VERSION}
EOF
echo "========================================="
echo "Issue body saved to: ${OUTPUT_FILE}"
echo "========================================="
echo ""
cat "$OUTPUT_FILE"
echo ""
echo "========================================="
echo ""
echo "Next steps:"
echo "1. Go to https://bugs.freebsd.org/bugzilla/ and login"
echo "2. Click 'Report an update or defect to a port'"
echo "3. Fill in:"
echo " - Severity: Affects Some People"
echo " - Summary: security/netbird: Update to ${NEW_VERSION}"
echo " - Description: Copy content from ${OUTPUT_FILE}"
echo "4. Attach diff file: netbird-${NEW_VERSION}.diff"
echo "5. Submit the bug report"

View File

@@ -111,6 +111,8 @@ func (c *GrpcClient) ready() bool {
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
backOff := defaultBackoff(ctx)
operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
connState := c.conn.GetState()
@@ -128,10 +130,10 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return err
}
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler, backOff)
}
err := backoff.Retry(operation, defaultBackoff(ctx))
err := backoff.Retry(operation, backOff)
if err != nil {
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
}
@@ -140,7 +142,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
}
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
msgHandler func(msg *proto.SyncResponse) error) error {
msgHandler func(msg *proto.SyncResponse) error, backOff backoff.BackOff) error {
ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
@@ -158,6 +160,9 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key,
// blocking until error
err = c.receiveEvents(stream, serverPubKey, msgHandler)
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
// reset times and next try will start with a long delay
backOff.Reset()
if err != nil {
c.notifyDisconnected(err)
s, _ := gstatus.FromError(err)

View File

@@ -488,6 +488,8 @@ components:
description: Indicates whether the peer is ephemeral or not
type: boolean
example: false
local_flags:
$ref: '#/components/schemas/PeerLocalFlags'
required:
- city_name
- connected
@@ -514,6 +516,49 @@ components:
- serial_number
- extra_dns_labels
- ephemeral
PeerLocalFlags:
type: object
properties:
rosenpass_enabled:
description: Indicates whether Rosenpass is enabled on this peer
type: boolean
example: true
rosenpass_permissive:
description: Indicates whether Rosenpass is in permissive mode or not
type: boolean
example: false
server_ssh_allowed:
description: Indicates whether SSH access this peer is allowed or not
type: boolean
example: true
disable_client_routes:
description: Indicates whether client routes are disabled on this peer or not
type: boolean
example: false
disable_server_routes:
description: Indicates whether server routes are disabled on this peer or not
type: boolean
example: false
disable_dns:
description: Indicates whether DNS management is disabled on this peer or not
type: boolean
example: false
disable_firewall:
description: Indicates whether firewall management is disabled on this peer or not
type: boolean
example: false
block_lan_access:
description: Indicates whether LAN access is blocked on this peer when used as a routing peer
type: boolean
example: false
block_inbound:
description: Indicates whether inbound traffic is blocked on this peer
type: boolean
example: false
lazy_connection_enabled:
description: Indicates whether lazy connection is enabled on this peer
type: boolean
example: false
PeerTemporaryAccessRequest:
type: object
properties:
@@ -936,7 +981,7 @@ components:
protocol:
description: Policy rule type of the traffic
type: string
enum: ["all", "tcp", "udp", "icmp"]
enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"]
example: "tcp"
ports:
description: Policy rule affected ports
@@ -949,6 +994,14 @@ components:
type: array
items:
$ref: '#/components/schemas/RulePortRange'
authorized_groups:
description: Map of user group ids to a list of local users
type: object
additionalProperties:
type: array
items:
type: string
example: "group1"
required:
- name
- enabled

View File

@@ -130,10 +130,11 @@ const (
// Defines values for PolicyRuleProtocol.
const (
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh"
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
)
// Defines values for PolicyRuleMinimumAction.
@@ -144,10 +145,11 @@ const (
// Defines values for PolicyRuleMinimumProtocol.
const (
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh"
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
)
// Defines values for PolicyRuleUpdateAction.
@@ -158,10 +160,11 @@ const (
// Defines values for PolicyRuleUpdateProtocol.
const (
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh"
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
)
// Defines values for ResourceType.
@@ -1077,7 +1080,8 @@ type Peer struct {
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
LastSeen time.Time `json:"last_seen"`
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
@@ -1167,7 +1171,8 @@ type PeerBatch struct {
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
LastSeen time.Time `json:"last_seen"`
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
@@ -1197,6 +1202,39 @@ type PeerBatch struct {
Version string `json:"version"`
}
// PeerLocalFlags defines model for PeerLocalFlags.
type PeerLocalFlags struct {
// BlockInbound Indicates whether inbound traffic is blocked on this peer
BlockInbound *bool `json:"block_inbound,omitempty"`
// BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer
BlockLanAccess *bool `json:"block_lan_access,omitempty"`
// DisableClientRoutes Indicates whether client routes are disabled on this peer or not
DisableClientRoutes *bool `json:"disable_client_routes,omitempty"`
// DisableDns Indicates whether DNS management is disabled on this peer or not
DisableDns *bool `json:"disable_dns,omitempty"`
// DisableFirewall Indicates whether firewall management is disabled on this peer or not
DisableFirewall *bool `json:"disable_firewall,omitempty"`
// DisableServerRoutes Indicates whether server routes are disabled on this peer or not
DisableServerRoutes *bool `json:"disable_server_routes,omitempty"`
// LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
// RosenpassEnabled Indicates whether Rosenpass is enabled on this peer
RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"`
// RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not
RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"`
// ServerSshAllowed Indicates whether SSH access this peer is allowed or not
ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"`
}
// PeerMinimum defines model for PeerMinimum.
type PeerMinimum struct {
// Id Peer ID
@@ -1349,6 +1387,9 @@ type PolicyRule struct {
// Action Policy rule accept or drops packets
Action PolicyRuleAction `json:"action"`
// AuthorizedGroups Map of user group ids to a list of local users
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"`
@@ -1393,6 +1434,9 @@ type PolicyRuleMinimum struct {
// Action Policy rule accept or drops packets
Action PolicyRuleMinimumAction `json:"action"`
// AuthorizedGroups Map of user group ids to a list of local users
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"`
@@ -1426,6 +1470,9 @@ type PolicyRuleUpdate struct {
// Action Policy rule accept or drops packets
Action PolicyRuleUpdateAction `json:"action"`
// AuthorizedGroups Map of user group ids to a list of local users
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"`

File diff suppressed because it is too large Load Diff

View File

@@ -332,6 +332,24 @@ message NetworkMap {
bool routesFirewallRulesIsEmpty = 11;
repeated ForwardingRule forwardingRules = 12;
// SSHAuth represents SSH authorization configuration
SSHAuth sshAuth = 13;
}
message SSHAuth {
// UserIDClaim is the JWT claim to be used to get the users ID
string UserIDClaim = 1;
// AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH
repeated bytes AuthorizedUsers = 2;
// MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list
map<string, MachineUserIndexes> machine_users = 3;
}
message MachineUserIndexes {
repeated uint32 indexes = 1;
}
// RemotePeerConfig represents a configuration of a remote peer.

View File

@@ -0,0 +1,28 @@
package sshauth
import (
"encoding/hex"
"golang.org/x/crypto/blake2b"
)
// UserIDHash represents a hashed user ID (BLAKE2b-128)
type UserIDHash [16]byte
// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value
// This function must produce the same hash on both client and management server
func HashUserID(userID string) (UserIDHash, error) {
hash, err := blake2b.New(16, nil)
if err != nil {
return UserIDHash{}, err
}
hash.Write([]byte(userID))
var result UserIDHash
copy(result[:], hash.Sum(nil))
return result, nil
}
// String returns the hexadecimal string representation of the hash
func (h UserIDHash) String() string {
return hex.EncodeToString(h[:])
}

View File

@@ -0,0 +1,210 @@
package sshauth
import (
"testing"
)
func TestHashUserID(t *testing.T) {
tests := []struct {
name string
userID string
}{
{
name: "simple user ID",
userID: "user@example.com",
},
{
name: "UUID format",
userID: "550e8400-e29b-41d4-a716-446655440000",
},
{
name: "numeric ID",
userID: "12345",
},
{
name: "empty string",
userID: "",
},
{
name: "special characters",
userID: "user+test@domain.com",
},
{
name: "unicode characters",
userID: "用户@example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hash, err := HashUserID(tt.userID)
if err != nil {
t.Errorf("HashUserID() error = %v, want nil", err)
return
}
// Verify hash is non-zero for non-empty inputs
if tt.userID != "" && hash == [16]byte{} {
t.Errorf("HashUserID() returned zero hash for non-empty input")
}
})
}
}
func TestHashUserID_Consistency(t *testing.T) {
userID := "test@example.com"
hash1, err1 := HashUserID(userID)
if err1 != nil {
t.Fatalf("First HashUserID() error = %v", err1)
}
hash2, err2 := HashUserID(userID)
if err2 != nil {
t.Fatalf("Second HashUserID() error = %v", err2)
}
if hash1 != hash2 {
t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2)
}
}
func TestHashUserID_Uniqueness(t *testing.T) {
tests := []struct {
userID1 string
userID2 string
}{
{"user1@example.com", "user2@example.com"},
{"alice@domain.com", "bob@domain.com"},
{"test", "test1"},
{"", "a"},
}
for _, tt := range tests {
hash1, err1 := HashUserID(tt.userID1)
if err1 != nil {
t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1)
}
hash2, err2 := HashUserID(tt.userID2)
if err2 != nil {
t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2)
}
if hash1 == hash2 {
t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1)
}
}
}
func TestUserIDHash_String(t *testing.T) {
tests := []struct {
name string
hash UserIDHash
expected string
}{
{
name: "zero hash",
hash: [16]byte{},
expected: "00000000000000000000000000000000",
},
{
name: "small value",
hash: [16]byte{15: 0xff},
expected: "000000000000000000000000000000ff",
},
{
name: "large value",
hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe},
expected: "0000000000000000deadbeefcafebabe",
},
{
name: "max value",
hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
expected: "ffffffffffffffffffffffffffffffff",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.hash.String()
if result != tt.expected {
t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected)
}
})
}
}
func TestUserIDHash_String_Length(t *testing.T) {
// Test that String() always returns 32 hex characters (16 bytes * 2)
userID := "test@example.com"
hash, err := HashUserID(userID)
if err != nil {
t.Fatalf("HashUserID() error = %v", err)
}
result := hash.String()
if len(result) != 32 {
t.Errorf("UserIDHash.String() length = %d, want 32", len(result))
}
// Verify it's valid hex
for i, c := range result {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c)
}
}
}
func TestHashUserID_KnownValues(t *testing.T) {
// Test with known BLAKE2b-128 values to ensure correct implementation
tests := []struct {
name string
userID string
expected UserIDHash
}{
{
name: "empty string",
userID: "",
// BLAKE2b-128 of empty string
expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70},
},
{
name: "single character 'a'",
userID: "a",
// BLAKE2b-128 of "a"
expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hash, err := HashUserID(tt.userID)
if err != nil {
t.Errorf("HashUserID() error = %v", err)
return
}
if hash != tt.expected {
t.Errorf("HashUserID(%q) = %x, want %x",
tt.userID, hash, tt.expected)
}
})
}
}
func BenchmarkHashUserID(b *testing.B) {
userID := "user@example.com"
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = HashUserID(userID)
}
}
func BenchmarkUserIDHash_String(b *testing.B) {
hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe})
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = hash.String()
}
}

View File

@@ -2,12 +2,10 @@ package semaphoregroup
import (
"context"
"sync"
)
// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore.
type SemaphoreGroup struct {
waitGroup sync.WaitGroup
semaphore chan struct{}
}
@@ -18,31 +16,18 @@ func NewSemaphoreGroup(limit int) *SemaphoreGroup {
}
}
// Add increments the internal WaitGroup counter and acquires a semaphore slot.
func (sg *SemaphoreGroup) Add(ctx context.Context) {
sg.waitGroup.Add(1)
// Add acquire a slot
func (sg *SemaphoreGroup) Add(ctx context.Context) error {
// Acquire semaphore slot
select {
case <-ctx.Done():
return
return ctx.Err()
case sg.semaphore <- struct{}{}:
return nil
}
}
// Done decrements the internal WaitGroup counter and releases a semaphore slot.
func (sg *SemaphoreGroup) Done(ctx context.Context) {
sg.waitGroup.Done()
// Release semaphore slot
select {
case <-ctx.Done():
return
case <-sg.semaphore:
}
}
// Wait waits until the internal WaitGroup counter is zero.
func (sg *SemaphoreGroup) Wait() {
sg.waitGroup.Wait()
// Done releases a slot. Must be called after a successful Add.
func (sg *SemaphoreGroup) Done() {
<-sg.semaphore
}

View File

@@ -2,65 +2,89 @@ package semaphoregroup
import (
"context"
"sync"
"testing"
"time"
)
func TestSemaphoreGroup(t *testing.T) {
semGroup := NewSemaphoreGroup(2)
for i := 0; i < 5; i++ {
semGroup.Add(context.Background())
go func(id int) {
defer semGroup.Done(context.Background())
got := len(semGroup.semaphore)
if got == 0 {
t.Errorf("Expected semaphore length > 0 , got 0")
}
time.Sleep(time.Millisecond)
t.Logf("Goroutine %d is running\n", id)
}(i)
}
semGroup.Wait()
want := 0
got := len(semGroup.semaphore)
if got != want {
t.Errorf("Expected semaphore length %d, got %d", want, got)
}
}
func TestSemaphoreGroupContext(t *testing.T) {
semGroup := NewSemaphoreGroup(1)
semGroup.Add(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
_ = semGroup.Add(context.Background())
ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
t.Cleanup(cancel)
rChan := make(chan struct{})
go func() {
semGroup.Add(ctx)
rChan <- struct{}{}
}()
select {
case <-rChan:
case <-time.NewTimer(2 * time.Second).C:
t.Error("Adding to semaphore group should not block when context is not done")
}
semGroup.Done(context.Background())
ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second)
t.Cleanup(cancelDone)
go func() {
semGroup.Done(ctxDone)
rChan <- struct{}{}
}()
select {
case <-rChan:
case <-time.NewTimer(2 * time.Second).C:
t.Error("Releasing from semaphore group should not block when context is not done")
if err := semGroup.Add(ctxTimeout); err == nil {
t.Error("Adding to semaphore group should not block")
}
}
func TestSemaphoreGroupFreeUp(t *testing.T) {
semGroup := NewSemaphoreGroup(1)
_ = semGroup.Add(context.Background())
semGroup.Done()
ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
t.Cleanup(cancel)
if err := semGroup.Add(ctxTimeout); err != nil {
t.Error(err)
}
}
func TestSemaphoreGroupCanceledContext(t *testing.T) {
semGroup := NewSemaphoreGroup(1)
_ = semGroup.Add(context.Background())
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
if err := semGroup.Add(ctx); err == nil {
t.Error("Add should return error when context is already canceled")
}
}
func TestSemaphoreGroupCancelWhileWaiting(t *testing.T) {
semGroup := NewSemaphoreGroup(1)
_ = semGroup.Add(context.Background())
ctx, cancel := context.WithCancel(context.Background())
errChan := make(chan error, 1)
go func() {
errChan <- semGroup.Add(ctx)
}()
time.Sleep(10 * time.Millisecond)
cancel()
if err := <-errChan; err == nil {
t.Error("Add should return error when context is canceled while waiting")
}
}
func TestSemaphoreGroupHighConcurrency(t *testing.T) {
const limit = 10
const numGoroutines = 100
semGroup := NewSemaphoreGroup(limit)
var wg sync.WaitGroup
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := semGroup.Add(context.Background()); err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
time.Sleep(time.Millisecond)
semGroup.Done()
}()
}
wg.Wait()
// Verify all slots were released
if got := len(semGroup.semaphore); got != 0 {
t.Errorf("Expected semaphore to be empty, got %d slots occupied", got)
}
}