mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 17:56:39 +00:00
Compare commits
22 Commits
coderabbit
...
v0.61.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
08b782d6ba | ||
|
|
80a312cc9c | ||
|
|
9ba067391f | ||
|
|
7ac65bf1ad | ||
|
|
2e9c316852 | ||
|
|
96cdd56902 | ||
|
|
9ed1437442 | ||
|
|
a8604ef51c | ||
|
|
d88e046d00 | ||
|
|
1d2c7776fd | ||
|
|
4035f07248 | ||
|
|
ef2721f4e1 | ||
|
|
e11970e32e | ||
|
|
38f9d5ed58 | ||
|
|
b6a327e0c9 | ||
|
|
67f7b2404e | ||
|
|
73201c4f3e | ||
|
|
33d1761fe8 | ||
|
|
aa914a0f26 | ||
|
|
ab6a9e85de | ||
|
|
d3b123c76d | ||
|
|
fc4932a23f |
2
.github/workflows/golang-test-freebsd.yml
vendored
2
.github/workflows/golang-test-freebsd.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||
time go test -timeout 8m -failfast -v -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
time go test -timeout 1m -failfast ./formatter/...
|
||||
|
||||
96
.github/workflows/release.yml
vendored
96
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.23"
|
||||
SIGN_PIPE_VER: "v0.1.0"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -19,6 +19,100 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release_freebsd_port:
|
||||
name: "FreeBSD Port / Build & Test"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
|
||||
- name: Generate FreeBSD port issue body
|
||||
run: bash release_files/freebsd-port-issue-body.sh
|
||||
|
||||
- name: Check if diff was generated
|
||||
id: check_diff
|
||||
run: |
|
||||
if ls netbird-*.diff 1> /dev/null 2>&1; then
|
||||
echo "diff_exists=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "diff_exists=false" >> $GITHUB_OUTPUT
|
||||
echo "No diff file generated (port may already be up to date)"
|
||||
fi
|
||||
|
||||
- name: Extract version
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
id: version
|
||||
run: |
|
||||
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "Generated files for version: $VERSION"
|
||||
cat netbird-*.diff
|
||||
|
||||
- name: Test FreeBSD port
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "15.0"
|
||||
prepare: |
|
||||
# Install required packages
|
||||
pkg install -y git curl portlint go
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
|
||||
# Clone ports tree (shallow, only what we need)
|
||||
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
|
||||
cd /usr/ports
|
||||
|
||||
run: |
|
||||
set -e -x
|
||||
export PATH=$PATH:/usr/local/go/bin
|
||||
|
||||
# Find the diff file
|
||||
echo "Finding diff file..."
|
||||
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
|
||||
echo "Found: $DIFF_FILE"
|
||||
|
||||
if [[ -z "$DIFF_FILE" ]]; then
|
||||
echo "ERROR: Could not find diff file"
|
||||
find ~ -name "*.diff" -type f 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
|
||||
cd /usr/ports
|
||||
patch -p1 -V none < "$DIFF_FILE"
|
||||
|
||||
# Show patched Makefile
|
||||
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||
|
||||
cd /usr/ports/security/netbird
|
||||
export BATCH=yes
|
||||
make package
|
||||
pkg add ./work/pkg/netbird-*.pkg
|
||||
|
||||
netbird version | grep "$version"
|
||||
|
||||
echo "FreeBSD port test completed successfully!"
|
||||
|
||||
- name: Upload FreeBSD port files
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: freebsd-port-files
|
||||
path: |
|
||||
./netbird-*-issue.txt
|
||||
./netbird-*.diff
|
||||
retention-days: 30
|
||||
|
||||
release:
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
|
||||
@@ -113,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
|
||||
<p float="left" align="middle">
|
||||
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
@@ -386,6 +386,97 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||
t.Skipf("iptables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// First ensure iptables-nft tables exist by running iptables-save
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
const octet2Count = 25
|
||||
const octet3Count = 255
|
||||
prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1))
|
||||
for i := 1; i < octet2Count; i++ {
|
||||
for j := 1; j < octet3Count; j++ {
|
||||
addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0})
|
||||
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
||||
}
|
||||
}
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
prefixes,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||
t.Skipf("iptables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// First ensure iptables-nft tables exist by running iptables-save
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||
t.Helper()
|
||||
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||
|
||||
@@ -48,9 +48,11 @@ const (
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||
maxPrefixesSet = 1500
|
||||
refreshRulesMapError = "refresh rules map: %w"
|
||||
)
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||
@@ -513,16 +515,35 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.AddSet(nfset, elements); err != nil {
|
||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||
}
|
||||
nElements := len(elements)
|
||||
|
||||
maxElements := maxPrefixesSet * 2
|
||||
initialElements := elements[:min(maxElements, nElements)]
|
||||
|
||||
if err := r.conn.AddSet(nfset, initialElements); err != nil {
|
||||
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
|
||||
|
||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||
var subEnd int
|
||||
for subStart := maxElements; subStart < nElements; subStart += maxElements {
|
||||
subEnd = min(subStart+maxElements, nElements)
|
||||
subElement := elements[subStart:subEnd]
|
||||
nSubPrefixes := len(subElement) / 2
|
||||
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
|
||||
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
|
||||
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
|
||||
}
|
||||
|
||||
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -45,10 +46,31 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
|
||||
}
|
||||
}
|
||||
|
||||
// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0).
|
||||
// This typically means the Swift code couldn't find the utun control socket.
|
||||
var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)")
|
||||
|
||||
func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
log.Infof("create tun interface")
|
||||
|
||||
dupTunFd, err := unix.Dup(t.tunFd)
|
||||
var tunDevice tun.Device
|
||||
var err error
|
||||
|
||||
// Validate the tunnel file descriptor.
|
||||
// On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider.
|
||||
// A value of 0 means the Swift code couldn't find the utun control socket
|
||||
// (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in
|
||||
// tvOS SDK headers). This is a hard error - there's no viable fallback
|
||||
// since tun.CreateTUN() cannot work within the iOS/tvOS sandbox.
|
||||
if t.tunFd == 0 {
|
||||
log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " +
|
||||
"On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.")
|
||||
return nil, ErrInvalidTunnelFD
|
||||
}
|
||||
|
||||
// Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider
|
||||
var dupTunFd int
|
||||
dupTunFd, err = unix.Dup(t.tunFd)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to dup tun fd: %v", err)
|
||||
return nil, err
|
||||
@@ -60,7 +82,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
_ = unix.Close(dupTunFd)
|
||||
return nil, err
|
||||
}
|
||||
tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
|
||||
tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to create new tun device from fd: %v", err)
|
||||
_ = unix.Close(dupTunFd)
|
||||
|
||||
@@ -80,6 +80,7 @@ type DefaultServer struct {
|
||||
updateSerial uint64
|
||||
previousConfigHash uint64
|
||||
currentConfig HostDNSConfig
|
||||
currentConfigHash uint64
|
||||
handlerChain *HandlerChain
|
||||
extraDomains map[domain.Domain]int
|
||||
|
||||
@@ -207,6 +208,7 @@ func newDefaultServer(
|
||||
hostsDNSHolder: newHostsDNSHolder(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
mgmtCacheResolver: mgmtCacheResolver,
|
||||
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
||||
}
|
||||
|
||||
// register with root zone, handler chain takes care of the routing
|
||||
@@ -586,8 +588,29 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
|
||||
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
|
||||
|
||||
hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err)
|
||||
// Fall through to apply config anyway (fail-safe approach)
|
||||
} else if s.currentConfigHash == hash {
|
||||
log.Debugf("not applying host config as there are no changes")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("applying host config as there are changes")
|
||||
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Only update hash if it was computed successfully and config was applied
|
||||
if err == nil {
|
||||
s.currentConfigHash = hash
|
||||
}
|
||||
|
||||
s.registerFallback(config)
|
||||
|
||||
@@ -1602,7 +1602,10 @@ func TestExtraDomains(t *testing.T) {
|
||||
"other.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 4,
|
||||
// Expect 3 calls instead of 4 because when deregistering duplicate.example.com,
|
||||
// the domain remains in the config (ref count goes from 2 to 1), so the host
|
||||
// config hash doesn't change and applyDNSConfig is not called.
|
||||
applyHostConfigCall: 3,
|
||||
},
|
||||
{
|
||||
name: "Config update with new domains after registration",
|
||||
@@ -1657,7 +1660,10 @@ func TestExtraDomains(t *testing.T) {
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 3,
|
||||
// Expect 2 calls instead of 3 because when deregistering protected.example.com,
|
||||
// it's removed from extraDomains but still remains in the config (from customZones),
|
||||
// so the host config hash doesn't change and applyDNSConfig is not called.
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register domain that is part of nameserver group",
|
||||
|
||||
@@ -1121,6 +1121,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
|
||||
// Filter out own peer from the remote peers list
|
||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||
for _, p := range networkMap.GetRemotePeers() {
|
||||
if p.GetWgPubKey() != localPubKey {
|
||||
remotePeers = append(remotePeers, p)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup request, most likely our peer has been deleted
|
||||
if networkMap.GetRemotePeersIsEmpty() {
|
||||
err := e.removeAllPeers()
|
||||
@@ -1129,32 +1138,34 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := e.removePeers(networkMap.GetRemotePeers())
|
||||
err := e.removePeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.modifyPeers(networkMap.GetRemotePeers())
|
||||
err = e.modifyPeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.addNewPeers(networkMap.GetRemotePeers())
|
||||
err = e.addNewPeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
|
||||
e.updatePeerSSHHostKeys(remotePeers)
|
||||
|
||||
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
||||
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
@@ -11,15 +11,18 @@ import (
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
type sshServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort) error
|
||||
Stop() error
|
||||
GetStatus() (bool, []sshserver.SessionInfo)
|
||||
UpdateSSHAuth(config *sshauth.Config)
|
||||
}
|
||||
|
||||
func (e *Engine) setupSSHPortRedirection() error {
|
||||
@@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
|
||||
|
||||
return sshServer.GetStatus()
|
||||
}
|
||||
|
||||
// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
|
||||
func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
|
||||
if sshAuth == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if e.sshServer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := sshAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range sshAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
// Update SSH server with new authorization configuration
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshAuth.GetUserIDClaim(),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
}
|
||||
|
||||
e.sshServer.UpdateSSHAuth(authConfig)
|
||||
}
|
||||
|
||||
@@ -148,13 +148,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||
// be used.
|
||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
conn.semaphore.Add(engineCtx)
|
||||
if err := conn.semaphore.Add(engineCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.opened {
|
||||
conn.semaphore.Done(engineCtx)
|
||||
conn.semaphore.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -165,6 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
conn.semaphore.Done()
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
@@ -200,7 +203,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
defer conn.wg.Done()
|
||||
|
||||
conn.waitInitialRandomSleepTime(conn.ctx)
|
||||
conn.semaphore.Done(conn.ctx)
|
||||
conn.semaphore.Done()
|
||||
|
||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||
}()
|
||||
|
||||
@@ -3,6 +3,7 @@ package profilemanager
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -820,3 +821,85 @@ func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// DirectWriteOutConfig writes config directly without atomic temp file operations.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectWriteOutConfig(path string, config *Config) error {
|
||||
return util.DirectWriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
|
||||
// Enforce permissions on existing config files (same as UpdateOrCreateConfig)
|
||||
if err := util.EnforcePermission(input.ConfigPath); err != nil {
|
||||
log.Errorf("failed to enforce permission on config file: %v", err)
|
||||
}
|
||||
|
||||
return directUpdate(input)
|
||||
}
|
||||
|
||||
func directUpdate(input ConfigInput) (*Config, error) {
|
||||
config := &Config{}
|
||||
|
||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := config.apply(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// ConfigToJSON serializes a Config struct to a JSON string.
|
||||
// This is useful for exporting config to alternative storage mechanisms
|
||||
// (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||
func ConfigToJSON(config *Config) (string, error) {
|
||||
bs, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bs), nil
|
||||
}
|
||||
|
||||
// ConfigFromJSON deserializes a JSON string to a Config struct.
|
||||
// This is useful for restoring config from alternative storage mechanisms.
|
||||
// After unmarshaling, defaults are applied to ensure the config is fully initialized.
|
||||
func ConfigFromJSON(jsonStr string) (*Config, error) {
|
||||
config := &Config{}
|
||||
err := json.Unmarshal([]byte(jsonStr), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply defaults to ensure required fields are initialized.
|
||||
// This mirrors what readConfig does after loading from file.
|
||||
if _, err := config.apply(ConfigInput{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply defaults to config: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ const (
|
||||
|
||||
defaultTempDir = "/var/lib/netbird/tmp-install"
|
||||
|
||||
pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
|
||||
pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -22,8 +22,8 @@ const (
|
||||
|
||||
msiLogFile = "msi.log"
|
||||
|
||||
msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
|
||||
exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
|
||||
msiDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
|
||||
exeDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -75,6 +75,8 @@ type Client struct {
|
||||
dnsManager dns.IosDnsManager
|
||||
loginComplete bool
|
||||
connectClient *internal.ConnectClient
|
||||
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
|
||||
preloadedConfig *profilemanager.Config
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
@@ -92,17 +94,44 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfigFromJSON loads config from a JSON string into memory.
|
||||
// This is used on tvOS where file writes to App Group containers are blocked.
|
||||
// When set, IsLoginRequired() and Run() will use this preloaded config instead of reading from file.
|
||||
func (c *Client) SetConfigFromJSON(jsonStr string) error {
|
||||
cfg, err := profilemanager.ConfigFromJSON(jsonStr)
|
||||
if err != nil {
|
||||
log.Errorf("SetConfigFromJSON: failed to parse config JSON: %v", err)
|
||||
return err
|
||||
}
|
||||
c.preloadedConfig = cfg
|
||||
log.Infof("SetConfigFromJSON: config loaded successfully from JSON")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
log.Infof("Starting NetBird client")
|
||||
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
StateFilePath: c.stateFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
var cfg *profilemanager.Config
|
||||
var err error
|
||||
|
||||
// Use preloaded config if available (tvOS where file writes are blocked)
|
||||
if c.preloadedConfig != nil {
|
||||
log.Infof("Run: using preloaded config from memory")
|
||||
cfg = c.preloadedConfig
|
||||
} else {
|
||||
log.Infof("Run: loading config from file")
|
||||
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
|
||||
// which are blocked by the tvOS sandbox in App Group containers
|
||||
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
StateFilePath: c.stateFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||
@@ -120,7 +149,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
c.ctxCancelLock.Unlock()
|
||||
|
||||
auth := NewAuthWithConfig(ctx, cfg)
|
||||
err = auth.Login()
|
||||
err = auth.LoginSync()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -208,14 +237,45 @@ func (c *Client) IsLoginRequired() bool {
|
||||
defer c.ctxCancelLock.Unlock()
|
||||
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||
|
||||
cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
var cfg *profilemanager.Config
|
||||
var err error
|
||||
|
||||
needsLogin, _ := internal.IsLoginRequired(ctx, cfg)
|
||||
// Use preloaded config if available (tvOS where file writes are blocked)
|
||||
if c.preloadedConfig != nil {
|
||||
log.Infof("IsLoginRequired: using preloaded config from memory")
|
||||
cfg = c.preloadedConfig
|
||||
} else {
|
||||
log.Infof("IsLoginRequired: loading config from file")
|
||||
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
|
||||
// which are blocked by the tvOS sandbox in App Group containers
|
||||
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("IsLoginRequired: failed to load config: %v", err)
|
||||
// If we can't load config, assume login is required
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
log.Errorf("IsLoginRequired: config is nil")
|
||||
return true
|
||||
}
|
||||
|
||||
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("IsLoginRequired: check failed: %v", err)
|
||||
// If the check fails, assume login is required to be safe
|
||||
return true
|
||||
}
|
||||
log.Infof("IsLoginRequired: needsLogin=%v", needsLogin)
|
||||
return needsLogin
|
||||
}
|
||||
|
||||
// loginForMobileAuthTimeout is the timeout for requesting auth info from the server
|
||||
const loginForMobileAuthTimeout = 30 * time.Second
|
||||
|
||||
func (c *Client) LoginForMobile() string {
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
@@ -228,16 +288,26 @@ func (c *Client) LoginForMobile() string {
|
||||
defer c.ctxCancelLock.Unlock()
|
||||
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||
|
||||
cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
|
||||
// which are blocked by the tvOS sandbox in App Group containers
|
||||
cfg, err := profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("LoginForMobile: failed to load config: %v", err)
|
||||
return fmt.Sprintf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
||||
authInfoCtx, authInfoCancel := context.WithTimeout(ctx, loginForMobileAuthTimeout)
|
||||
defer authInfoCancel()
|
||||
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
@@ -249,10 +319,14 @@ func (c *Client) LoginForMobile() string {
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
||||
return
|
||||
}
|
||||
jwtToken := tokenInfo.GetTokenToUse()
|
||||
_ = internal.Login(ctx, cfg, "", jwtToken)
|
||||
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
|
||||
log.Errorf("LoginForMobile: Login failed: %v", err)
|
||||
return
|
||||
}
|
||||
c.loginComplete = true
|
||||
}()
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
@@ -33,7 +34,8 @@ type ErrListener interface {
|
||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||
// the backend want to show an url for the user
|
||||
type URLOpener interface {
|
||||
Open(string)
|
||||
Open(url string, userCode string)
|
||||
OnLoginSuccess()
|
||||
}
|
||||
|
||||
// Auth can register or login new client
|
||||
@@ -72,13 +74,32 @@ func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth
|
||||
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||
func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
|
||||
func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
if listener == nil {
|
||||
log.Errorf("SaveConfigIfSSOSupported: listener is nil")
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
sso, err := a.saveConfigIfSSOSupported()
|
||||
if err != nil {
|
||||
listener.OnError(err)
|
||||
} else {
|
||||
listener.OnSuccess(sso)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
@@ -97,12 +118,29 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||
// which are blocked by the tvOS sandbox in App Group containers
|
||||
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
||||
return true, err
|
||||
}
|
||||
|
||||
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
|
||||
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) {
|
||||
if resultListener == nil {
|
||||
log.Errorf("LoginWithSetupKeyAndSaveConfig: resultListener is nil")
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName)
|
||||
if err != nil {
|
||||
resultListener.OnError(err)
|
||||
} else {
|
||||
resultListener.OnSuccess()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
|
||||
@@ -118,10 +156,14 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||
// which are blocked by the tvOS sandbox in App Group containers
|
||||
return profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
||||
}
|
||||
|
||||
func (a *Auth) Login() error {
|
||||
// LoginSync performs a synchronous login check without UI interaction
|
||||
// Used for background VPN connection where user should already be authenticated
|
||||
func (a *Auth) LoginSync() error {
|
||||
var needsLogin bool
|
||||
|
||||
// check if we need to generate JWT token
|
||||
@@ -135,23 +177,142 @@ func (a *Auth) Login() error {
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
return fmt.Errorf("Not authenticated")
|
||||
return fmt.Errorf("not authenticated")
|
||||
}
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Login performs interactive login with device authentication support
|
||||
// Deprecated: Use LoginWithDeviceName instead to ensure proper device naming on tvOS
|
||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool) {
|
||||
// Use empty device name - system will use hostname as fallback
|
||||
a.LoginWithDeviceName(resultListener, urlOpener, forceDeviceAuth, "")
|
||||
}
|
||||
|
||||
// LoginWithDeviceName performs interactive login with device authentication support
|
||||
// The deviceName parameter allows specifying a custom device name (required for tvOS)
|
||||
func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool, deviceName string) {
|
||||
if resultListener == nil {
|
||||
log.Errorf("LoginWithDeviceName: resultListener is nil")
|
||||
return
|
||||
}
|
||||
if urlOpener == nil {
|
||||
log.Errorf("LoginWithDeviceName: urlOpener is nil")
|
||||
resultListener.OnError(fmt.Errorf("urlOpener is nil"))
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
err := a.login(urlOpener, forceDeviceAuth, deviceName)
|
||||
if err != nil {
|
||||
resultListener.OnError(err)
|
||||
} else {
|
||||
resultListener.OnSuccess()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
||||
var needsLogin bool
|
||||
|
||||
// Create context with device name if provided
|
||||
ctx := a.ctx
|
||||
if deviceName != "" {
|
||||
//nolint:staticcheck
|
||||
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
}
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err = a.withBackOff(ctx, func() error {
|
||||
err := internal.Login(ctx, a.config, "", jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
// Save the config before notifying success to ensure persistence completes
|
||||
// before the callback potentially triggers teardown on the Swift side.
|
||||
// Note: This differs from Android which doesn't save config after login.
|
||||
// On iOS/tvOS, we save here because:
|
||||
// 1. The config may have been modified during login (e.g., new tokens)
|
||||
// 2. On tvOS, the Network Extension context may be the only place with
|
||||
// write permissions to the App Group container
|
||||
if a.cfgPath != "" {
|
||||
if err := profilemanager.DirectWriteOutConfig(a.cfgPath, a.config); err != nil {
|
||||
log.Warnf("failed to save config after login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Notify caller of successful login synchronously before returning
|
||||
urlOpener.OnLoginSuccess()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const authInfoRequestTimeout = 30 * time.Second
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
||||
authInfoCtx, authInfoCancel := context.WithTimeout(a.ctx, authInfoRequestTimeout)
|
||||
defer authInfoCancel()
|
||||
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
return backoff.RetryNotify(
|
||||
bf,
|
||||
@@ -160,3 +321,24 @@ func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||
})
|
||||
}
|
||||
|
||||
// GetConfigJSON returns the current config as a JSON string.
|
||||
// This can be used by the caller to persist the config via alternative storage
|
||||
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||
func (a *Auth) GetConfigJSON() (string, error) {
|
||||
if a.config == nil {
|
||||
return "", fmt.Errorf("no config available")
|
||||
}
|
||||
return profilemanager.ConfigToJSON(a.config)
|
||||
}
|
||||
|
||||
// SetConfigFromJSON loads config from a JSON string.
|
||||
// This can be used to restore config from alternative storage mechanisms.
|
||||
func (a *Auth) SetConfigFromJSON(jsonStr string) error {
|
||||
cfg, err := profilemanager.ConfigFromJSON(jsonStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.config = cfg
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -112,6 +112,8 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||
|
||||
// Commit write out the changes into config file
|
||||
func (p *Preferences) Commit() error {
|
||||
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
|
||||
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
|
||||
// which are blocked by the tvOS sandbox in App Group containers
|
||||
_, err := profilemanager.DirectUpdateOrCreateConfig(p.configInput)
|
||||
return err
|
||||
}
|
||||
|
||||
184
client/ssh/auth/auth.go
Normal file
184
client/ssh/auth/auth.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultUserIDClaim is the default JWT claim used to extract user IDs
|
||||
DefaultUserIDClaim = "sub"
|
||||
// Wildcard is a special user ID that matches all users
|
||||
Wildcard = "*"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
)
|
||||
|
||||
// Authorizer handles SSH fine-grained access control authorization
|
||||
type Authorizer struct {
|
||||
// UserIDClaim is the JWT claim to extract the user ID from
|
||||
userIDClaim string
|
||||
|
||||
// authorizedUsers is a list of hashed user IDs authorized to access this peer
|
||||
authorizedUsers []sshuserhash.UserIDHash
|
||||
|
||||
// machineUsers maps OS login usernames to lists of authorized user indexes
|
||||
machineUsers map[string][]uint32
|
||||
|
||||
// mu protects the list of users
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Config contains configuration for the SSH authorizer
|
||||
type Config struct {
|
||||
// UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
|
||||
UserIDClaim string
|
||||
|
||||
// AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
|
||||
AuthorizedUsers []sshuserhash.UserIDHash
|
||||
|
||||
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
|
||||
// If a user wants to login as a specific OS user, their index must be in the corresponding list
|
||||
MachineUsers map[string][]uint32
|
||||
}
|
||||
|
||||
// NewAuthorizer creates a new SSH authorizer with empty configuration
|
||||
func NewAuthorizer() *Authorizer {
|
||||
a := &Authorizer{
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// Update updates the authorizer configuration with new values
|
||||
func (a *Authorizer) Update(config *Config) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if config == nil {
|
||||
// Clear authorization
|
||||
a.userIDClaim = DefaultUserIDClaim
|
||||
a.authorizedUsers = []sshuserhash.UserIDHash{}
|
||||
a.machineUsers = make(map[string][]uint32)
|
||||
log.Info("SSH authorization cleared")
|
||||
return
|
||||
}
|
||||
|
||||
userIDClaim := config.UserIDClaim
|
||||
if userIDClaim == "" {
|
||||
userIDClaim = DefaultUserIDClaim
|
||||
}
|
||||
a.userIDClaim = userIDClaim
|
||||
|
||||
// Store authorized users list
|
||||
a.authorizedUsers = config.AuthorizedUsers
|
||||
|
||||
// Store machine users mapping
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range config.MachineUsers {
|
||||
if len(indexes) > 0 {
|
||||
machineUsers[osUser] = indexes
|
||||
}
|
||||
}
|
||||
a.machineUsers = machineUsers
|
||||
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
|
||||
len(config.AuthorizedUsers), len(machineUsers))
|
||||
}
|
||||
|
||||
// Authorize validates if a user is authorized to login as the specified OS user
|
||||
// Returns nil if authorized, or an error describing why authorization failed
|
||||
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
|
||||
if jwtUserID == "" {
|
||||
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
|
||||
return ErrEmptyUserID
|
||||
}
|
||||
|
||||
// Hash the JWT user ID for comparison
|
||||
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
|
||||
if err != nil {
|
||||
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
|
||||
return fmt.Errorf("failed to hash user ID: %w", err)
|
||||
}
|
||||
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
// Find the index of this user in the authorized list
|
||||
userIndex, found := a.findUserIndex(hashedUserID)
|
||||
if !found {
|
||||
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
|
||||
return ErrUserNotAuthorized
|
||||
}
|
||||
|
||||
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
|
||||
}
|
||||
|
||||
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
|
||||
// Checks wildcard mapping first, then specific OS user mappings
|
||||
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
|
||||
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
|
||||
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
|
||||
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
|
||||
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check for specific OS username mapping
|
||||
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
|
||||
if !hasMachineUserMapping {
|
||||
// No mapping for this OS user - deny by default (fail closed)
|
||||
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
|
||||
return ErrNoMachineUserMapping
|
||||
}
|
||||
|
||||
// Check if user's index is in the allowed indexes for this specific OS user
|
||||
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
|
||||
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
|
||||
return ErrUserNotMappedToOSUser
|
||||
}
|
||||
|
||||
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserIDClaim returns the JWT claim name used to extract user IDs
|
||||
func (a *Authorizer) GetUserIDClaim() string {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
return a.userIDClaim
|
||||
}
|
||||
|
||||
// findUserIndex finds the index of a hashed user ID in the authorized users list
|
||||
// Returns the index and true if found, 0 and false if not found
|
||||
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
|
||||
for i, id := range a.authorizedUsers {
|
||||
if id == hashedUserID {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// isIndexInList checks if an index exists in a list of indexes
|
||||
func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
|
||||
for _, idx := range indexes {
|
||||
if idx == index {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
612
client/ssh/auth/auth_test.go
Normal file
612
client/ssh/auth/auth_test.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up authorized users list with one user
|
||||
authorizedUserHash, err := sshauth.HashUserID("authorized-user")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
|
||||
MachineUsers: map[string][]uint32{},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Try to authorize a different user
|
||||
err = authorizer.Authorize("unauthorized-user", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||
MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// All attempts should fail when no machine user mappings exist (fail closed)
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
err = authorizer.Authorize("user2", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
user3Hash, err := sshauth.HashUserID("user3")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0, 1}, // user1 and user2 can access root
|
||||
"postgres": {1, 2}, // user2 and user3 can access postgres
|
||||
"admin": {0}, // only user1 can access admin
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 (index 0) should access root and admin
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user1", "admin")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// user2 (index 1) should access root and postgres
|
||||
err = authorizer.Authorize("user2", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user2", "postgres")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// user3 (index 2) should access postgres
|
||||
err = authorizer.Authorize("user3", "postgres")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up authorized users list
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
user3Hash, err := sshauth.HashUserID("user3")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0, 1}, // user1 and user2 can access root
|
||||
"postgres": {1, 2}, // user2 and user3 can access postgres
|
||||
"admin": {0}, // only user1 can access admin
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 (index 0) should NOT access postgres
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
|
||||
// user2 (index 1) should NOT access admin
|
||||
err = authorizer.Authorize("user2", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
|
||||
// user3 (index 2) should NOT access root
|
||||
err = authorizer.Authorize("user3", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
|
||||
// user3 (index 2) should NOT access admin
|
||||
err = authorizer.Authorize("user3", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up authorized users list
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0}, // only root is mapped
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 should NOT access an unmapped OS user (fail closed)
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up authorized users list
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Empty user ID should fail
|
||||
err = authorizer.Authorize("", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrEmptyUserID)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up multiple authorized users
|
||||
userHashes := make([]sshauth.UserIDHash, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
|
||||
require.NoError(t, err)
|
||||
userHashes[i] = hash
|
||||
}
|
||||
|
||||
// Create machine user mapping for all users
|
||||
rootIndexes := make([]uint32, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
rootIndexes[i] = uint32(i)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: userHashes,
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": rootIndexes,
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// All users should be authorized for root
|
||||
for i := 0; i < 10; i++ {
|
||||
err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
|
||||
assert.NoError(t, err, "user%d should be authorized", i)
|
||||
}
|
||||
|
||||
// User not in list should fail
|
||||
err := authorizer.Authorize("unknown-user", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up initial configuration
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{"root": {0}},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 should be authorized
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Clear configuration
|
||||
authorizer.Update(nil)
|
||||
|
||||
// user1 should no longer be authorized
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Machine users with empty index lists should be filtered out
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0},
|
||||
"postgres": {}, // empty list - should be filtered out
|
||||
"admin": nil, // nil list - should be filtered out
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// root should work
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// postgres should fail (no mapping)
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
// admin should fail (no mapping)
|
||||
err = authorizer.Authorize("user1", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
}
|
||||
|
||||
func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up with custom user ID claim
|
||||
user1Hash, err := sshauth.HashUserID("user@example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: "email",
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0},
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Verify the custom claim is set
|
||||
assert.Equal(t, "email", authorizer.GetUserIDClaim())
|
||||
|
||||
// Authorize with email as user ID
|
||||
err = authorizer.Authorize("user@example.com", "root")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Verify default claim
|
||||
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
|
||||
assert.Equal(t, "sub", authorizer.GetUserIDClaim())
|
||||
|
||||
// Set up with empty user ID claim (should use default)
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: "", // empty - should use default
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Should fall back to default
|
||||
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
|
||||
}
|
||||
|
||||
func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Create a large authorized users list
|
||||
const numUsers = 1000
|
||||
userHashes := make([]sshauth.UserIDHash, numUsers)
|
||||
for i := 0; i < numUsers; i++ {
|
||||
hash, err := sshauth.HashUserID("user" + string(rune(i)))
|
||||
require.NoError(t, err)
|
||||
userHashes[i] = hash
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: userHashes,
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0, 500, 999}, // first, middle, and last user
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// First user should have access
|
||||
err := authorizer.Authorize("user"+string(rune(0)), "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Middle user should have access
|
||||
err = authorizer.Authorize("user"+string(rune(500)), "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Last user should have access
|
||||
err = authorizer.Authorize("user"+string(rune(999)), "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// User not in mapping should NOT have access
|
||||
err = authorizer.Authorize("user"+string(rune(100)), "root")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up authorized users
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0, 1},
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Test concurrent authorization calls (should be safe to read concurrently)
|
||||
const numGoroutines = 100
|
||||
errChan := make(chan error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(idx int) {
|
||||
user := "user1"
|
||||
if idx%2 == 0 {
|
||||
user = "user2"
|
||||
}
|
||||
err := authorizer.Authorize(user, "root")
|
||||
errChan <- err
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete and collect errors
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
err := <-errChan
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
user3Hash, err := sshauth.HashUserID("user3")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Configure with wildcard - all authorized users can access any OS user
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"*": {0, 1, 2}, // wildcard with all user indexes
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// All authorized users should be able to access any OS user
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user2", "postgres")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user3", "admin")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user1", "ubuntu")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user2", "nginx")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user3", "docker")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Configure with wildcard
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"*": {0},
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 should have access
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Unauthorized user should still be denied even with wildcard
|
||||
err = authorizer.Authorize("unauthorized-user", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Configure with both wildcard and specific mappings
|
||||
// Wildcard takes precedence for users in the wildcard index list
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"*": {0, 1}, // wildcard for both users
|
||||
"root": {0}, // specific mapping that would normally restrict to user1 only
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user2", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Both users should be able to access any other OS user via wildcard
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user2", "admin")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Configure WITHOUT wildcard - only specific mappings
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0}, // only user1
|
||||
"postgres": {1}, // only user2
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 can access root
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// user2 can access postgres
|
||||
err = authorizer.Authorize("user2", "postgres")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// user1 cannot access postgres
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
|
||||
// user2 cannot access root
|
||||
err = authorizer.Authorize("user2", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
|
||||
// Neither can access unmapped OS users
|
||||
err = authorizer.Authorize("user1", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
err = authorizer.Authorize("user2", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
|
||||
// This test covers the scenario where wildcard exists with limited indexes.
|
||||
// Only users whose indexes are in the wildcard list can access any OS user via wildcard.
|
||||
// Other users can only access OS users they are explicitly mapped to.
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Create two authorized user hashes (simulating the base64-encoded hashes in the config)
|
||||
wasmHash, err := sshauth.HashUserID("wasm")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Configure with wildcard having only index 0, and specific mappings for other OS users
|
||||
config := &Config{
|
||||
UserIDClaim: "sub",
|
||||
AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"*": {0}, // wildcard with only index 0 - only wasm has wildcard access
|
||||
"alice": {1}, // specific mapping for user2
|
||||
"bob": {1}, // specific mapping for user2
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// wasm (index 0) should access any OS user via wildcard
|
||||
err = authorizer.Authorize("wasm", "root")
|
||||
assert.NoError(t, err, "wasm should access root via wildcard")
|
||||
|
||||
err = authorizer.Authorize("wasm", "alice")
|
||||
assert.NoError(t, err, "wasm should access alice via wildcard")
|
||||
|
||||
err = authorizer.Authorize("wasm", "bob")
|
||||
assert.NoError(t, err, "wasm should access bob via wildcard")
|
||||
|
||||
err = authorizer.Authorize("wasm", "postgres")
|
||||
assert.NoError(t, err, "wasm should access postgres via wildcard")
|
||||
|
||||
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
|
||||
err = authorizer.Authorize("user2", "alice")
|
||||
assert.NoError(t, err, "user2 should access alice via explicit mapping")
|
||||
|
||||
err = authorizer.Authorize("user2", "bob")
|
||||
assert.NoError(t, err, "user2 should access bob via explicit mapping")
|
||||
|
||||
err = authorizer.Authorize("user2", "root")
|
||||
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
err = authorizer.Authorize("user2", "postgres")
|
||||
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
// Unauthorized user should still be denied
|
||||
err = authorizer.Authorize("user3", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
||||
}
|
||||
@@ -27,9 +27,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
// Configure SSH authorization for the test user
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0}, // Index 0 in AuthorizedUsers
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
defer func() { _ = sshServer.Stop() }()
|
||||
|
||||
@@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
@@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
||||
t.Helper()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"sub": user,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
@@ -23,10 +23,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func TestJWTEnforcement(t *testing.T) {
|
||||
@@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) {
|
||||
tc.setupServer(server)
|
||||
}
|
||||
|
||||
// Always set up authorization for test-user to ensure tests fail at JWT validation stage
|
||||
testUserHash, err := sshuserhash.HashUserID("test-user")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get current OS username for machine user mapping
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
currentUser: {0}, // Allow test-user (index 0) to access current OS user
|
||||
},
|
||||
}
|
||||
server.UpdateSSHAuth(authConfig)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
@@ -138,6 +139,8 @@ type Server struct {
|
||||
jwtExtractor *jwt.ClaimsExtractor
|
||||
jwtConfig *JWTConfig
|
||||
|
||||
authorizer *sshauth.Authorizer
|
||||
|
||||
suSupportsPty bool
|
||||
loginIsUtilLinux bool
|
||||
}
|
||||
@@ -179,6 +182,7 @@ func New(config *Config) *Server {
|
||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||
jwtEnabled: config.JWT != nil,
|
||||
jwtConfig: config.JWT,
|
||||
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -320,6 +324,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
||||
s.wgAddress = addr
|
||||
}
|
||||
|
||||
// UpdateSSHAuth updates the SSH fine-grained access control configuration
|
||||
// This should be called when network map updates include new SSH auth configuration
|
||||
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Reset JWT validator/extractor to pick up new userIDClaim
|
||||
s.jwtValidator = nil
|
||||
s.jwtExtractor = nil
|
||||
|
||||
s.authorizer.Update(config)
|
||||
}
|
||||
|
||||
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
||||
func (s *Server) ensureJWTValidator() error {
|
||||
s.mu.RLock()
|
||||
@@ -328,6 +345,7 @@ func (s *Server) ensureJWTValidator() error {
|
||||
return nil
|
||||
}
|
||||
config := s.jwtConfig
|
||||
authorizer := s.authorizer
|
||||
s.mu.RUnlock()
|
||||
|
||||
if config == nil {
|
||||
@@ -343,9 +361,16 @@ func (s *Server) ensureJWTValidator() error {
|
||||
true,
|
||||
)
|
||||
|
||||
extractor := jwt.NewClaimsExtractor(
|
||||
// Use custom userIDClaim from authorizer if available
|
||||
extractorOptions := []jwt.ClaimsExtractorOption{
|
||||
jwt.WithAudience(config.Audience),
|
||||
)
|
||||
}
|
||||
if authorizer.GetUserIDClaim() != "" {
|
||||
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
|
||||
log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
|
||||
}
|
||||
|
||||
extractor := jwt.NewClaimsExtractor(extractorOptions...)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -493,29 +518,41 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
|
||||
}
|
||||
|
||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||
osUsername := ctx.User()
|
||||
remoteAddr := ctx.RemoteAddr()
|
||||
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.validateJWTToken(password)
|
||||
if err != nil {
|
||||
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
userAuth, err := s.extractAndValidateUser(token)
|
||||
if err != nil {
|
||||
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
|
||||
s.mu.RLock()
|
||||
authorizer := s.authorizer
|
||||
s.mu.RUnlock()
|
||||
|
||||
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
|
||||
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
key := newAuthKey(osUsername, remoteAddr)
|
||||
s.mu.Lock()
|
||||
s.pendingAuthJWT[key] = userAuth.UserId
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -312,6 +312,8 @@ type serviceClient struct {
|
||||
daemonVersion string
|
||||
updateIndicationLock sync.Mutex
|
||||
isUpdateIconActive bool
|
||||
settingsEnabled bool
|
||||
profilesEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
@@ -907,7 +909,7 @@ func (s *serviceClient) updateStatus() error {
|
||||
var systrayIconState bool
|
||||
|
||||
switch {
|
||||
case status.Status == string(internal.StatusConnected):
|
||||
case status.Status == string(internal.StatusConnected) && !s.connected:
|
||||
s.connected = true
|
||||
s.sendNotification = true
|
||||
if s.isUpdateIconActive {
|
||||
@@ -921,6 +923,7 @@ func (s *serviceClient) updateStatus() error {
|
||||
s.mUp.Disable()
|
||||
s.mDown.Enable()
|
||||
s.mNetworks.Enable()
|
||||
s.mExitNode.Enable()
|
||||
go s.updateExitNodes()
|
||||
systrayIconState = true
|
||||
case status.Status == string(internal.StatusConnecting):
|
||||
@@ -1274,19 +1277,22 @@ func (s *serviceClient) checkAndUpdateFeatures() {
|
||||
return
|
||||
}
|
||||
|
||||
s.updateIndicationLock.Lock()
|
||||
defer s.updateIndicationLock.Unlock()
|
||||
|
||||
// Update settings menu based on current features
|
||||
if features != nil && features.DisableUpdateSettings {
|
||||
s.setSettingsEnabled(false)
|
||||
} else {
|
||||
s.setSettingsEnabled(true)
|
||||
settingsEnabled := features == nil || !features.DisableUpdateSettings
|
||||
if s.settingsEnabled != settingsEnabled {
|
||||
s.settingsEnabled = settingsEnabled
|
||||
s.setSettingsEnabled(settingsEnabled)
|
||||
}
|
||||
|
||||
// Update profile menu based on current features
|
||||
if s.mProfile != nil {
|
||||
if features != nil && features.DisableProfiles {
|
||||
s.mProfile.setEnabled(false)
|
||||
} else {
|
||||
s.mProfile.setEnabled(true)
|
||||
profilesEnabled := features == nil || !features.DisableProfiles
|
||||
if s.profilesEnabled != profilesEnabled {
|
||||
s.profilesEnabled = profilesEnabled
|
||||
s.mProfile.setEnabled(profilesEnabled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,6 @@ func (s *serviceClient) getWindowsFontFilePath() string {
|
||||
"chr-CHER-US": "Gadugi.ttf",
|
||||
"zh-HK": "Segoeui.ttf",
|
||||
"zh-TW": "Segoeui.ttf",
|
||||
"ja-JP": "Yugothm.ttc",
|
||||
"km-KH": "Leelawui.ttf",
|
||||
"ko-KR": "Malgun.ttf",
|
||||
"th-TH": "Leelawui.ttf",
|
||||
|
||||
17
go.mod
17
go.mod
@@ -22,7 +22,7 @@ require (
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/grpc v1.73.0
|
||||
google.golang.org/grpc v1.75.0
|
||||
google.golang.org/protobuf v1.36.8
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
)
|
||||
@@ -41,6 +41,7 @@ require (
|
||||
github.com/coder/websocket v1.8.13
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/dexidp/dex/api/v2 v2.4.0
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||
@@ -97,10 +98,10 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.4
|
||||
github.com/zcalusic/sysinfo v1.1.3
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0
|
||||
go.opentelemetry.io/otel v1.35.0
|
||||
go.opentelemetry.io/otel v1.37.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
||||
go.opentelemetry.io/otel/metric v1.35.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.35.0
|
||||
go.opentelemetry.io/otel/metric v1.37.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0
|
||||
go.uber.org/mock v0.5.0
|
||||
go.uber.org/zap v1.27.0
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
@@ -124,7 +125,7 @@ require (
|
||||
require (
|
||||
cloud.google.com/go/auth v0.3.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.7.0 // indirect
|
||||
dario.cat/mergo v1.0.0 // indirect
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
||||
@@ -170,7 +171,7 @@ require (
|
||||
github.com/fyne-io/oksvg v0.2.0 // indirect
|
||||
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
@@ -248,8 +249,8 @@ require (
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.33.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
|
||||
40
go.sum
40
go.sum
@@ -4,8 +4,8 @@ cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
|
||||
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
|
||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
||||
cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU=
|
||||
cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo=
|
||||
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
|
||||
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
|
||||
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
|
||||
@@ -117,6 +117,8 @@ github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70J
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A=
|
||||
github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
@@ -164,8 +166,8 @@ github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVin
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
@@ -561,22 +563,22 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.4
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
|
||||
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
|
||||
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
|
||||
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
|
||||
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
|
||||
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
|
||||
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
|
||||
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
|
||||
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
|
||||
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
|
||||
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
|
||||
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
|
||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
|
||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
@@ -761,6 +763,8 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk=
|
||||
google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
@@ -770,8 +774,8 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA
|
||||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||
google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 h1:FiusG7LWj+4byqhbvmB+Q93B/mOxJLN2DTozDuZm4EU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:kXqgZtrWaf6qS3jZOCnCH7WYfrvFjkC51bM8fz3RsCA=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
@@ -779,8 +783,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
|
||||
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
|
||||
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
||||
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
||||
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
|
||||
google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
|
||||
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
|
||||
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
|
||||
@@ -53,7 +53,8 @@ services:
|
||||
command: [
|
||||
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
|
||||
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||
"--log-file", "console"
|
||||
"--log-file", "console",
|
||||
"--port", "80"
|
||||
]
|
||||
|
||||
# Relay
|
||||
|
||||
554
infrastructure_files/getting-started-with-dex.sh
Executable file
554
infrastructure_files/getting-started-with-dex.sh
Executable file
@@ -0,0 +1,554 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# NetBird Getting Started with Dex IDP
|
||||
# This script sets up NetBird with Dex as the identity provider
|
||||
|
||||
# Sed pattern to strip base64 padding characters
|
||||
SED_STRIP_PADDING='s/=//g'
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null
|
||||
then
|
||||
echo "docker-compose"
|
||||
return
|
||||
fi
|
||||
if docker compose --help &> /dev/null
|
||||
then
|
||||
echo "docker compose"
|
||||
return
|
||||
fi
|
||||
|
||||
echo "docker-compose is not installed or not in PATH. Please follow the steps from the official guide: https://docs.docker.com/engine/install/" > /dev/stderr
|
||||
exit 1
|
||||
}
|
||||
|
||||
check_jq() {
|
||||
if ! command -v jq &> /dev/null
|
||||
then
|
||||
echo "jq is not installed or not in PATH, please install with your package manager. e.g. sudo apt install jq" > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
get_main_ip_address() {
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
interface=$(route -n get default | grep 'interface:' | awk '{print $2}')
|
||||
ip_address=$(ifconfig "$interface" | grep 'inet ' | awk '{print $2}')
|
||||
else
|
||||
interface=$(ip route | grep default | awk '{print $5}' | head -n 1)
|
||||
ip_address=$(ip addr show "$interface" | grep 'inet ' | awk '{print $2}' | cut -d'/' -f1)
|
||||
fi
|
||||
|
||||
echo "$ip_address"
|
||||
return 0
|
||||
}
|
||||
|
||||
check_nb_domain() {
|
||||
DOMAIN=$1
|
||||
if [[ "$DOMAIN-x" == "-x" ]]; then
|
||||
echo "The NETBIRD_DOMAIN variable cannot be empty." > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
|
||||
if [[ "$DOMAIN" == "netbird.example.com" ]]; then
|
||||
echo "The NETBIRD_DOMAIN cannot be netbird.example.com" > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
read_nb_domain() {
|
||||
READ_NETBIRD_DOMAIN=""
|
||||
echo -n "Enter the domain you want to use for NetBird (e.g. netbird.my-domain.com): " > /dev/stderr
|
||||
read -r READ_NETBIRD_DOMAIN < /dev/tty
|
||||
if ! check_nb_domain "$READ_NETBIRD_DOMAIN"; then
|
||||
read_nb_domain
|
||||
fi
|
||||
echo "$READ_NETBIRD_DOMAIN"
|
||||
return 0
|
||||
}
|
||||
|
||||
get_turn_external_ip() {
|
||||
TURN_EXTERNAL_IP_CONFIG="#external-ip="
|
||||
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip')
|
||||
if [[ "x-$IP" != "x-" ]]; then
|
||||
TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
|
||||
fi
|
||||
echo "$TURN_EXTERNAL_IP_CONFIG"
|
||||
return 0
|
||||
}
|
||||
|
||||
wait_dex() {
|
||||
set +e
|
||||
echo -n "Waiting for Dex to become ready (via $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN)"
|
||||
counter=1
|
||||
while true; do
|
||||
# Check Dex through Caddy proxy (also validates TLS is working)
|
||||
if curl -sk -f -o /dev/null "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/.well-known/openid-configuration" 2>/dev/null; then
|
||||
break
|
||||
fi
|
||||
if [[ $counter -eq 60 ]]; then
|
||||
echo ""
|
||||
echo "Taking too long. Checking logs..."
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 caddy
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 dex
|
||||
fi
|
||||
echo -n " ."
|
||||
sleep 2
|
||||
counter=$((counter + 1))
|
||||
done
|
||||
echo " done"
|
||||
set -e
|
||||
return 0
|
||||
}
|
||||
|
||||
init_environment() {
|
||||
CADDY_SECURE_DOMAIN=""
|
||||
NETBIRD_PORT=80
|
||||
NETBIRD_HTTP_PROTOCOL="http"
|
||||
NETBIRD_RELAY_PROTO="rel"
|
||||
TURN_USER="self"
|
||||
TURN_PASSWORD=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
|
||||
NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
|
||||
TURN_MIN_PORT=49152
|
||||
TURN_MAX_PORT=65535
|
||||
TURN_EXTERNAL_IP_CONFIG=$(get_turn_external_ip)
|
||||
|
||||
# Generate secrets for Dex
|
||||
DEX_DASHBOARD_CLIENT_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
|
||||
|
||||
# Generate admin password
|
||||
NETBIRD_ADMIN_PASSWORD=$(openssl rand -base64 16 | sed "$SED_STRIP_PADDING")
|
||||
|
||||
if ! check_nb_domain "$NETBIRD_DOMAIN"; then
|
||||
NETBIRD_DOMAIN=$(read_nb_domain)
|
||||
fi
|
||||
|
||||
if [[ "$NETBIRD_DOMAIN" == "use-ip" ]]; then
|
||||
NETBIRD_DOMAIN=$(get_main_ip_address)
|
||||
else
|
||||
NETBIRD_PORT=443
|
||||
CADDY_SECURE_DOMAIN=", $NETBIRD_DOMAIN:$NETBIRD_PORT"
|
||||
NETBIRD_HTTP_PROTOCOL="https"
|
||||
NETBIRD_RELAY_PROTO="rels"
|
||||
fi
|
||||
|
||||
check_jq
|
||||
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
|
||||
if [[ -f dex.yaml ]]; then
|
||||
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
||||
echo "You can use the following commands:"
|
||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
||||
echo " rm -f docker-compose.yml Caddyfile dex.yaml dashboard.env turnserver.conf management.json relay.env"
|
||||
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo Rendering initial files...
|
||||
render_docker_compose > docker-compose.yml
|
||||
render_caddyfile > Caddyfile
|
||||
render_dex_config > dex.yaml
|
||||
render_dashboard_env > dashboard.env
|
||||
render_management_json > management.json
|
||||
render_turn_server_conf > turnserver.conf
|
||||
render_relay_env > relay.env
|
||||
|
||||
echo -e "\nStarting Dex IDP\n"
|
||||
$DOCKER_COMPOSE_COMMAND up -d caddy dex
|
||||
|
||||
# Wait for Dex to be ready (through caddy proxy)
|
||||
sleep 3
|
||||
wait_dex
|
||||
|
||||
echo -e "\nStarting NetBird services\n"
|
||||
$DOCKER_COMPOSE_COMMAND up -d
|
||||
|
||||
echo -e "\nDone!\n"
|
||||
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||
echo ""
|
||||
echo "Login with the following credentials:"
|
||||
echo "Email: admin@$NETBIRD_DOMAIN" | tee .env
|
||||
echo "Password: $NETBIRD_ADMIN_PASSWORD" | tee -a .env
|
||||
echo ""
|
||||
echo "Dex admin UI is not available (Dex has no built-in UI)."
|
||||
echo "To add more users, edit dex.yaml and restart: $DOCKER_COMPOSE_COMMAND restart dex"
|
||||
return 0
|
||||
}
|
||||
|
||||
render_caddyfile() {
|
||||
cat <<EOF
|
||||
{
|
||||
debug
|
||||
servers :80,:443 {
|
||||
protocols h1 h2c h2 h3
|
||||
}
|
||||
}
|
||||
|
||||
(security_headers) {
|
||||
header * {
|
||||
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
|
||||
X-Content-Type-Options "nosniff"
|
||||
X-Frame-Options "SAMEORIGIN"
|
||||
X-XSS-Protection "1; mode=block"
|
||||
-Server
|
||||
Referrer-Policy strict-origin-when-cross-origin
|
||||
}
|
||||
}
|
||||
|
||||
:80${CADDY_SECURE_DOMAIN} {
|
||||
import security_headers
|
||||
# Relay
|
||||
reverse_proxy /relay* relay:80
|
||||
# Signal
|
||||
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
|
||||
# Management
|
||||
reverse_proxy /api/* management:80
|
||||
reverse_proxy /management.ManagementService/* h2c://management:80
|
||||
# Dex
|
||||
reverse_proxy /dex/* dex:5556
|
||||
# Dashboard
|
||||
reverse_proxy /* dashboard:80
|
||||
}
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_dex_config() {
|
||||
# Generate bcrypt hash of the admin password
|
||||
# Using a simple approach - htpasswd or python if available
|
||||
ADMIN_PASSWORD_HASH=""
|
||||
if command -v htpasswd &> /dev/null; then
|
||||
ADMIN_PASSWORD_HASH=$(htpasswd -bnBC 10 "" "$NETBIRD_ADMIN_PASSWORD" | tr -d ':\n')
|
||||
elif command -v python3 &> /dev/null; then
|
||||
ADMIN_PASSWORD_HASH=$(python3 -c "import bcrypt; print(bcrypt.hashpw('$NETBIRD_ADMIN_PASSWORD'.encode(), bcrypt.gensalt(rounds=10)).decode())" 2>/dev/null || echo "")
|
||||
fi
|
||||
|
||||
# Fallback to a known hash if we can't generate one
|
||||
if [[ -z "$ADMIN_PASSWORD_HASH" ]]; then
|
||||
# This is hash of "password" - user should change it
|
||||
ADMIN_PASSWORD_HASH='$2a$10$2b2cU8CPhOTaGrs1HRQuAueS7JTT5ZHsHSzYiFPm1leZck7Mc8T4W'
|
||||
NETBIRD_ADMIN_PASSWORD="password"
|
||||
echo "Warning: Could not generate password hash. Using default password: password. Please change it in dex.yaml" > /dev/stderr
|
||||
fi
|
||||
|
||||
cat <<EOF
|
||||
# Dex configuration for NetBird
|
||||
# Generated by getting-started-with-dex.sh
|
||||
|
||||
issuer: $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex
|
||||
|
||||
storage:
|
||||
type: sqlite3
|
||||
config:
|
||||
file: /var/dex/dex.db
|
||||
|
||||
web:
|
||||
http: 0.0.0.0:5556
|
||||
|
||||
# gRPC API for user management (used by NetBird IDP manager)
|
||||
grpc:
|
||||
addr: 0.0.0.0:5557
|
||||
|
||||
oauth2:
|
||||
skipApprovalScreen: true
|
||||
|
||||
# Static OAuth2 clients for NetBird
|
||||
staticClients:
|
||||
# Dashboard client
|
||||
- id: netbird-dashboard
|
||||
name: NetBird Dashboard
|
||||
secret: $DEX_DASHBOARD_CLIENT_SECRET
|
||||
redirectURIs:
|
||||
- $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/nb-auth
|
||||
- $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/nb-silent-auth
|
||||
|
||||
# CLI client (public - uses PKCE)
|
||||
- id: netbird-cli
|
||||
name: NetBird CLI
|
||||
public: true
|
||||
redirectURIs:
|
||||
- http://localhost:53000/
|
||||
- http://localhost:54000/
|
||||
|
||||
# Enable password database for static users
|
||||
enablePasswordDB: true
|
||||
|
||||
# Static users - add more users here as needed
|
||||
staticPasswords:
|
||||
- email: "admin@$NETBIRD_DOMAIN"
|
||||
hash: "$ADMIN_PASSWORD_HASH"
|
||||
username: "admin"
|
||||
userID: "$(uuidgen 2>/dev/null || cat /proc/sys/kernel/random/uuid 2>/dev/null || echo "admin-user-id-001")"
|
||||
|
||||
# Optional: Add external identity provider connectors
|
||||
# connectors:
|
||||
# - type: github
|
||||
# id: github
|
||||
# name: GitHub
|
||||
# config:
|
||||
# clientID: \$GITHUB_CLIENT_ID
|
||||
# clientSecret: \$GITHUB_CLIENT_SECRET
|
||||
# redirectURI: $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/callback
|
||||
#
|
||||
# - type: ldap
|
||||
# id: ldap
|
||||
# name: LDAP
|
||||
# config:
|
||||
# host: ldap.example.com:636
|
||||
# insecureNoSSL: false
|
||||
# bindDN: cn=admin,dc=example,dc=com
|
||||
# bindPW: admin
|
||||
# userSearch:
|
||||
# baseDN: ou=users,dc=example,dc=com
|
||||
# filter: "(objectClass=person)"
|
||||
# username: uid
|
||||
# idAttr: uid
|
||||
# emailAttr: mail
|
||||
# nameAttr: cn
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_turn_server_conf() {
|
||||
cat <<EOF
|
||||
listening-port=3478
|
||||
$TURN_EXTERNAL_IP_CONFIG
|
||||
tls-listening-port=5349
|
||||
min-port=$TURN_MIN_PORT
|
||||
max-port=$TURN_MAX_PORT
|
||||
fingerprint
|
||||
lt-cred-mech
|
||||
user=$TURN_USER:$TURN_PASSWORD
|
||||
realm=wiretrustee.com
|
||||
cert=/etc/coturn/certs/cert.pem
|
||||
pkey=/etc/coturn/private/privkey.pem
|
||||
log-file=stdout
|
||||
no-software-attribute
|
||||
pidfile="/var/tmp/turnserver.pid"
|
||||
no-cli
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_management_json() {
|
||||
cat <<EOF
|
||||
{
|
||||
"Stuns": [
|
||||
{
|
||||
"Proto": "udp",
|
||||
"URI": "stun:$NETBIRD_DOMAIN:3478"
|
||||
}
|
||||
],
|
||||
"Relay": {
|
||||
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
|
||||
"CredentialsTTL": "24h",
|
||||
"Secret": "$NETBIRD_RELAY_AUTH_SECRET"
|
||||
},
|
||||
"Signal": {
|
||||
"Proto": "$NETBIRD_HTTP_PROTOCOL",
|
||||
"URI": "$NETBIRD_DOMAIN:$NETBIRD_PORT"
|
||||
},
|
||||
"HttpConfig": {
|
||||
"AuthIssuer": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex",
|
||||
"AuthAudience": "netbird-dashboard",
|
||||
"OIDCConfigEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/.well-known/openid-configuration"
|
||||
},
|
||||
"IdpManagerConfig": {
|
||||
"ManagerType": "dex",
|
||||
"ClientConfig": {
|
||||
"Issuer": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex"
|
||||
},
|
||||
"ExtraConfig": {
|
||||
"GRPCAddr": "dex:5557"
|
||||
}
|
||||
},
|
||||
"DeviceAuthorizationFlow": {
|
||||
"Provider": "hosted",
|
||||
"ProviderConfig": {
|
||||
"Audience": "netbird-cli",
|
||||
"ClientID": "netbird-cli",
|
||||
"Scope": "openid profile email offline_access",
|
||||
"DeviceAuthEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/device/code",
|
||||
"TokenEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/token"
|
||||
}
|
||||
},
|
||||
"PKCEAuthorizationFlow": {
|
||||
"ProviderConfig": {
|
||||
"Audience": "netbird-cli",
|
||||
"ClientID": "netbird-cli",
|
||||
"Scope": "openid profile email offline_access",
|
||||
"RedirectURLs": ["http://localhost:53000/", "http://localhost:54000/"]
|
||||
}
|
||||
}
|
||||
}
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_dashboard_env() {
|
||||
cat <<EOF
|
||||
# Endpoints
|
||||
NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN
|
||||
# OIDC
|
||||
AUTH_AUDIENCE=netbird-dashboard
|
||||
AUTH_CLIENT_ID=netbird-dashboard
|
||||
AUTH_CLIENT_SECRET=$DEX_DASHBOARD_CLIENT_SECRET
|
||||
AUTH_AUTHORITY=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex
|
||||
USE_AUTH0=false
|
||||
AUTH_SUPPORTED_SCOPES=openid profile email offline_access
|
||||
AUTH_REDIRECT_URI=/nb-auth
|
||||
AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||
# SSL
|
||||
NGINX_SSL_PORT=443
|
||||
# Letsencrypt
|
||||
LETSENCRYPT_DOMAIN=none
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_relay_env() {
|
||||
cat <<EOF
|
||||
NB_LOG_LEVEL=info
|
||||
NB_LISTEN_ADDRESS=:80
|
||||
NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT
|
||||
NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_docker_compose() {
|
||||
cat <<EOF
|
||||
services:
|
||||
# Caddy reverse proxy
|
||||
caddy:
|
||||
image: caddy
|
||||
container_name: netbird-caddy
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
ports:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
volumes:
|
||||
- netbird_caddy_data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Dex - identity provider
|
||||
dex:
|
||||
image: ghcr.io/dexidp/dex:v2.38.0
|
||||
container_name: netbird-dex
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
volumes:
|
||||
- ./dex.yaml:/etc/dex/config.docker.yaml:ro
|
||||
- netbird_dex_data:/var/dex
|
||||
command: ["dex", "serve", "/etc/dex/config.docker.yaml"]
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# UI dashboard
|
||||
dashboard:
|
||||
image: netbirdio/dashboard:latest
|
||||
container_name: netbird-dashboard
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
env_file:
|
||||
- ./dashboard.env
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Signal
|
||||
signal:
|
||||
image: netbirdio/signal:latest
|
||||
container_name: netbird-signal
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Relay
|
||||
relay:
|
||||
image: netbirdio/relay:latest
|
||||
container_name: netbird-relay
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
env_file:
|
||||
- ./relay.env
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Management
|
||||
management:
|
||||
image: netbirdio/management:latest
|
||||
container_name: netbird-management
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
volumes:
|
||||
- netbird_management:/var/lib/netbird
|
||||
- ./management.json:/etc/netbird/management.json
|
||||
command: [
|
||||
"--port", "80",
|
||||
"--log-file", "console",
|
||||
"--log-level", "info",
|
||||
"--disable-anonymous-metrics=false",
|
||||
"--single-account-mode-domain=netbird.selfhosted",
|
||||
"--dns-domain=netbird.selfhosted",
|
||||
"--idp-sign-key-refresh-enabled",
|
||||
]
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Coturn, AKA TURN server
|
||||
coturn:
|
||||
image: coturn/coturn
|
||||
container_name: netbird-coturn
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
volumes:
|
||||
netbird_caddy_data:
|
||||
netbird_dex_data:
|
||||
netbird_management:
|
||||
|
||||
networks:
|
||||
netbird:
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
init_environment
|
||||
@@ -178,6 +178,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
@@ -224,7 +225,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
@@ -320,6 +321,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerId)
|
||||
if err != nil {
|
||||
@@ -338,7 +340,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -445,7 +447,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -811,7 +813,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
|
||||
@@ -158,5 +158,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
||||
}
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,7 +6,10 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
@@ -16,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
@@ -84,15 +88,15 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
|
||||
return nbConfig
|
||||
}
|
||||
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig {
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig {
|
||||
netmask, _ := network.Net.Mask.Size()
|
||||
fqdn := peer.FQDN(dnsName)
|
||||
|
||||
sshConfig := &proto.SSHConfig{
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
SshEnabled: peer.SSHEnabled || enableSSH,
|
||||
}
|
||||
|
||||
if peer.SSHEnabled {
|
||||
if sshConfig.SshEnabled {
|
||||
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
|
||||
}
|
||||
|
||||
@@ -110,12 +114,12 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
||||
|
||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
@@ -151,9 +155,45 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
if networkMap.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
||||
|
||||
for machineUser, users := range authorizedUsers {
|
||||
indexes := make([]uint32, 0, len(users))
|
||||
for userID := range users {
|
||||
idx, exists := userIDToIndex[userID]
|
||||
if !exists {
|
||||
hash, err := sshauth.HashUserID(userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
|
||||
continue
|
||||
}
|
||||
idx = uint32(len(hashedUsers))
|
||||
userIDToIndex[userID] = idx
|
||||
hashedUsers = append(hashedUsers, hash[:])
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
||||
}
|
||||
|
||||
return hashedUsers, machineUsers
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
|
||||
@@ -184,8 +184,14 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
realIP := getRealIP(ctx)
|
||||
sRealIP := realIP.String()
|
||||
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
|
||||
userID, err := s.accountManager.GetUserIDByPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||
}
|
||||
@@ -270,6 +276,8 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync took %s", time.Since(reqStart))
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
@@ -559,6 +567,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart))
|
||||
}()
|
||||
|
||||
if loginReq.GetMeta() == nil {
|
||||
@@ -635,7 +644,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
|
||||
@@ -1456,21 +1456,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
}
|
||||
}
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
||||
}
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -2158,3 +2156,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) {
|
||||
return am.Store.GetUserIDByPeerKey(ctx, store.LockingStrengthNone, peerKey)
|
||||
}
|
||||
|
||||
@@ -123,4 +123,5 @@ type Manager interface {
|
||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
}
|
||||
|
||||
@@ -397,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
|
||||
@@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthUpdate, accountID, groupID)
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
@@ -442,6 +442,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
if len(groupIDsToDelete) == 0 {
|
||||
return allErrors
|
||||
}
|
||||
|
||||
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -299,7 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
@@ -369,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
|
||||
PortRanges: []types.RulePortRange{portRange},
|
||||
}},
|
||||
}
|
||||
if protocol == types.PolicyRuleProtocolNetbirdSSH {
|
||||
policy.Rules[0].AuthorizedUser = userAuth.UserId
|
||||
}
|
||||
|
||||
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
|
||||
if err != nil {
|
||||
@@ -449,6 +452,18 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
LocalFlags: &api.PeerLocalFlags{
|
||||
BlockInbound: &peer.Meta.Flags.BlockInbound,
|
||||
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
|
||||
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
|
||||
DisableDns: &peer.Meta.Flags.DisableDNS,
|
||||
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
|
||||
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
|
||||
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
|
||||
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
|
||||
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
|
||||
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
|
||||
},
|
||||
}
|
||||
|
||||
if !approved {
|
||||
@@ -463,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
||||
if osVersion == "" {
|
||||
osVersion = peer.Meta.Core
|
||||
}
|
||||
|
||||
return &api.PeerBatch{
|
||||
CreatedAt: peer.CreatedAt,
|
||||
Id: peer.ID,
|
||||
@@ -492,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
LocalFlags: &api.PeerLocalFlags{
|
||||
BlockInbound: &peer.Meta.Flags.BlockInbound,
|
||||
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
|
||||
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
|
||||
DisableDns: &peer.Meta.Flags.DisableDNS,
|
||||
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
|
||||
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
|
||||
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
|
||||
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
|
||||
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
|
||||
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
pr.Protocol = types.PolicyRuleProtocolUDP
|
||||
case api.PolicyRuleUpdateProtocolIcmp:
|
||||
pr.Protocol = types.PolicyRuleProtocolICMP
|
||||
case api.PolicyRuleUpdateProtocolNetbirdSsh:
|
||||
pr.Protocol = types.PolicyRuleProtocolNetbirdSSH
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
|
||||
return
|
||||
@@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
}
|
||||
}
|
||||
|
||||
if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 {
|
||||
for _, sourceGroupID := range pr.Sources {
|
||||
_, ok := (*rule.AuthorizedGroups)[sourceGroupID]
|
||||
if !ok {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
pr.AuthorizedGroups = *rule.AuthorizedGroups
|
||||
}
|
||||
|
||||
// validate policy object
|
||||
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
|
||||
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
|
||||
@@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
|
||||
DestinationResource: r.DestinationResource.ToAPIResponse(),
|
||||
}
|
||||
|
||||
if len(r.AuthorizedGroups) != 0 {
|
||||
authorizedGroupsCopy := r.AuthorizedGroups
|
||||
rule.AuthorizedGroups = &authorizedGroupsCopy
|
||||
}
|
||||
|
||||
if len(r.Ports) != 0 {
|
||||
portsCopy := r.Ports
|
||||
rule.Ports = &portsCopy
|
||||
|
||||
445
management/server/idp/dex.go
Normal file
445
management/server/idp/dex.go
Normal file
@@ -0,0 +1,445 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/api/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// DexManager implements the Manager interface for Dex IDP.
|
||||
// It uses Dex's gRPC API to manage users in the password database.
|
||||
type DexManager struct {
|
||||
grpcAddr string
|
||||
httpClient ManagerHTTPClient
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
mux sync.Mutex
|
||||
conn *grpc.ClientConn
|
||||
}
|
||||
|
||||
// DexClientConfig Dex manager client configuration.
|
||||
type DexClientConfig struct {
|
||||
// GRPCAddr is the address of Dex's gRPC API (e.g., "localhost:5557")
|
||||
GRPCAddr string
|
||||
// Issuer is the Dex issuer URL (e.g., "https://dex.example.com/dex")
|
||||
Issuer string
|
||||
}
|
||||
|
||||
// NewDexManager creates a new instance of DexManager.
|
||||
func NewDexManager(config DexClientConfig, appMetrics telemetry.AppMetrics) (*DexManager, error) {
|
||||
if config.GRPCAddr == "" {
|
||||
return nil, fmt.Errorf("dex IdP configuration is incomplete, GRPCAddr is missing")
|
||||
}
|
||||
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
return &DexManager{
|
||||
grpcAddr: config.GRPCAddr,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getConnection returns a gRPC connection to Dex, creating one if necessary.
|
||||
// It also checks if an existing connection is still healthy and reconnects if needed.
|
||||
func (dm *DexManager) getConnection(ctx context.Context) (*grpc.ClientConn, error) {
|
||||
dm.mux.Lock()
|
||||
defer dm.mux.Unlock()
|
||||
|
||||
if dm.conn != nil {
|
||||
state := dm.conn.GetState()
|
||||
// If connection is shutdown or in a transient failure, close and reconnect
|
||||
if state == connectivity.Shutdown || state == connectivity.TransientFailure {
|
||||
log.WithContext(ctx).Debugf("Dex gRPC connection in state %s, reconnecting", state)
|
||||
_ = dm.conn.Close()
|
||||
dm.conn = nil
|
||||
} else {
|
||||
return dm.conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("connecting to Dex gRPC API at %s", dm.grpcAddr)
|
||||
|
||||
conn, err := grpc.NewClient(dm.grpcAddr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Dex gRPC API: %w", err)
|
||||
}
|
||||
|
||||
dm.conn = conn
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// getDexClient returns a Dex API client.
|
||||
func (dm *DexManager) getDexClient(ctx context.Context) (api.DexClient, error) {
|
||||
conn, err := dm.getConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return api.NewDexClient(conn), nil
|
||||
}
|
||||
|
||||
// encodeDexUserID encodes a user ID and connector ID into Dex's composite format.
|
||||
// This is the reverse of parseDexUserID - it creates the base64-encoded protobuf
|
||||
// format that Dex uses in JWT tokens.
|
||||
func encodeDexUserID(userID, connectorID string) string {
|
||||
// Build simple protobuf structure:
|
||||
// Field 1 (tag 0x0a): user ID string
|
||||
// Field 2 (tag 0x12): connector ID string
|
||||
buf := make([]byte, 0, 2+len(userID)+2+len(connectorID))
|
||||
|
||||
// Field 1: user ID
|
||||
buf = append(buf, 0x0a) // tag for field 1, wire type 2 (length-delimited)
|
||||
buf = append(buf, byte(len(userID))) // length
|
||||
buf = append(buf, []byte(userID)...) // value
|
||||
|
||||
// Field 2: connector ID
|
||||
buf = append(buf, 0x12) // tag for field 2, wire type 2 (length-delimited)
|
||||
buf = append(buf, byte(len(connectorID))) // length
|
||||
buf = append(buf, []byte(connectorID)...) // value
|
||||
|
||||
return base64.StdEncoding.EncodeToString(buf)
|
||||
}
|
||||
|
||||
// parseDexUserID extracts the actual user ID from Dex's composite user ID.
|
||||
// Dex encodes user IDs in JWT tokens as base64-encoded protobuf with format:
|
||||
// - Field 1 (string): actual user ID
|
||||
// - Field 2 (string): connector ID (e.g., "local")
|
||||
// If the ID is not in this format, it returns the original ID.
|
||||
func parseDexUserID(compositeID string) string {
|
||||
// Try to decode as standard base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(compositeID)
|
||||
if err != nil {
|
||||
// Try URL-safe base64
|
||||
decoded, err = base64.RawURLEncoding.DecodeString(compositeID)
|
||||
if err != nil {
|
||||
// Not base64 encoded, return as-is
|
||||
return compositeID
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the simple protobuf structure
|
||||
// Field 1 (tag 0x0a): user ID string
|
||||
// Field 2 (tag 0x12): connector ID string
|
||||
if len(decoded) < 2 {
|
||||
return compositeID
|
||||
}
|
||||
|
||||
// Check for field 1 tag (0x0a = field 1, wire type 2/length-delimited)
|
||||
if decoded[0] != 0x0a {
|
||||
return compositeID
|
||||
}
|
||||
|
||||
// Read the length of the user ID string
|
||||
length := int(decoded[1])
|
||||
if len(decoded) < 2+length {
|
||||
return compositeID
|
||||
}
|
||||
|
||||
// Extract the user ID
|
||||
userID := string(decoded[2 : 2+length])
|
||||
return userID
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
// Dex doesn't support app metadata, so this is a no-op.
|
||||
func (dm *DexManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from Dex via user ID.
|
||||
func (dm *DexManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
client, err := dm.getDexClient(ctx)
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to list passwords from Dex: %w", err)
|
||||
}
|
||||
|
||||
// Try to parse the composite user ID from Dex JWT token
|
||||
actualUserID := parseDexUserID(userID)
|
||||
|
||||
for _, p := range resp.Passwords {
|
||||
// Match against both the raw userID and the parsed actualUserID
|
||||
if p.UserId == userID || p.UserId == actualUserID {
|
||||
return &UserData{
|
||||
Email: p.Email,
|
||||
Name: p.Username,
|
||||
ID: userID, // Return the original ID for consistency
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("user with ID %s not found", userID)
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given account.
|
||||
// Since Dex doesn't have account concepts, this returns all users.
|
||||
func (dm *DexManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users, err := dm.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the account ID for all users
|
||||
for _, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// Since Dex doesn't have account concepts, all users are returned under UnsetAccountID.
|
||||
func (dm *DexManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
users, err := dm.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = users
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Dex's password database.
|
||||
func (dm *DexManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
client, err := dm.getDexClient(ctx)
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate a random password for the new user
|
||||
password := GeneratePassword(16, 2, 2, 2)
|
||||
|
||||
// Hash the password using bcrypt
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Generate a user ID from email (Dex uses email as the key, but we need a stable ID)
|
||||
userID := strings.ReplaceAll(email, "@", "-at-")
|
||||
userID = strings.ReplaceAll(userID, ".", "-")
|
||||
|
||||
req := &api.CreatePasswordReq{
|
||||
Password: &api.Password{
|
||||
Email: email,
|
||||
Username: name,
|
||||
UserId: userID,
|
||||
Hash: hashedPassword,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.CreatePassword(ctx, req)
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to create user in Dex: %w", err)
|
||||
}
|
||||
|
||||
if resp.AlreadyExists {
|
||||
return nil, fmt.Errorf("user with email %s already exists", email)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("created user %s in Dex", email)
|
||||
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: userID,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (dm *DexManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
client, err := dm.getDexClient(ctx)
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to list passwords from Dex: %w", err)
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, p := range resp.Passwords {
|
||||
if strings.EqualFold(p.Email, email) {
|
||||
// Encode the user ID in Dex's composite format to match stored IDs
|
||||
encodedID := encodeDexUserID(p.UserId, "local")
|
||||
users = append(users, &UserData{
|
||||
Email: p.Email,
|
||||
Name: p.Username,
|
||||
ID: encodedID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resends an invitation to a user.
|
||||
// Dex doesn't support invitations, so this returns an error.
|
||||
func (dm *DexManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented for Dex")
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user from Dex by user ID.
|
||||
func (dm *DexManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
client, err := dm.getDexClient(ctx)
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// First, find the user's email by ID
|
||||
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return fmt.Errorf("failed to list passwords from Dex: %w", err)
|
||||
}
|
||||
|
||||
// Try to parse the composite user ID from Dex JWT token
|
||||
actualUserID := parseDexUserID(userID)
|
||||
|
||||
var email string
|
||||
for _, p := range resp.Passwords {
|
||||
if p.UserId == userID || p.UserId == actualUserID {
|
||||
email = p.Email
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if email == "" {
|
||||
return fmt.Errorf("user with ID %s not found", userID)
|
||||
}
|
||||
|
||||
// Delete the user by email
|
||||
deleteResp, err := client.DeletePassword(ctx, &api.DeletePasswordReq{
|
||||
Email: email,
|
||||
})
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return fmt.Errorf("failed to delete user from Dex: %w", err)
|
||||
}
|
||||
|
||||
if deleteResp.NotFound {
|
||||
return fmt.Errorf("user with email %s not found", email)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("deleted user %s from Dex", email)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAllUsers retrieves all users from Dex's password database.
|
||||
func (dm *DexManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
client, err := dm.getDexClient(ctx)
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
|
||||
if err != nil {
|
||||
if dm.appMetrics != nil {
|
||||
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to list passwords from Dex: %w", err)
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0, len(resp.Passwords))
|
||||
for _, p := range resp.Passwords {
|
||||
// Encode the user ID in Dex's composite format (base64-encoded protobuf)
|
||||
// to match how NetBird stores user IDs from Dex JWT tokens.
|
||||
// The connector ID "local" is used for Dex's password database.
|
||||
encodedID := encodeDexUserID(p.UserId, "local")
|
||||
users = append(users, &UserData{
|
||||
Email: p.Email,
|
||||
Name: p.Username,
|
||||
ID: encodedID,
|
||||
})
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
137
management/server/idp/dex_test.go
Normal file
137
management/server/idp/dex_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewDexManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig DexClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := DexClientConfig{
|
||||
GRPCAddr: "localhost:5557",
|
||||
Issuer: "https://dex.example.com/dex",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.GRPCAddr = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing GRPCAddr Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when GRPCAddr is empty",
|
||||
}
|
||||
|
||||
// Test with empty issuer - should still work since issuer is optional for the manager
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.Issuer = ""
|
||||
|
||||
testCase3 := test{
|
||||
name: "Missing Issuer Configuration - OK",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error when issuer is empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
manager, err := NewDexManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
if err == nil {
|
||||
require.NotNil(t, manager, "manager should not be nil")
|
||||
require.Equal(t, testCase.inputConfig.GRPCAddr, manager.grpcAddr, "grpcAddr should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDexManagerUpdateUserAppMetadata(t *testing.T) {
|
||||
config := DexClientConfig{
|
||||
GRPCAddr: "localhost:5557",
|
||||
Issuer: "https://dex.example.com/dex",
|
||||
}
|
||||
|
||||
manager, err := NewDexManager(config, &telemetry.MockAppMetrics{})
|
||||
require.NoError(t, err, "should create manager without error")
|
||||
|
||||
// UpdateUserAppMetadata should be a no-op for Dex
|
||||
err = manager.UpdateUserAppMetadata(context.Background(), "test-user-id", AppMetadata{
|
||||
WTAccountID: "test-account",
|
||||
})
|
||||
require.NoError(t, err, "UpdateUserAppMetadata should not return error")
|
||||
}
|
||||
|
||||
func TestDexManagerInviteUserByID(t *testing.T) {
|
||||
config := DexClientConfig{
|
||||
GRPCAddr: "localhost:5557",
|
||||
Issuer: "https://dex.example.com/dex",
|
||||
}
|
||||
|
||||
manager, err := NewDexManager(config, &telemetry.MockAppMetrics{})
|
||||
require.NoError(t, err, "should create manager without error")
|
||||
|
||||
// InviteUserByID should return an error for Dex
|
||||
err = manager.InviteUserByID(context.Background(), "test-user-id")
|
||||
require.Error(t, err, "InviteUserByID should return error")
|
||||
require.Contains(t, err.Error(), "not implemented", "error should mention not implemented")
|
||||
}
|
||||
|
||||
func TestParseDexUserID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
compositeID string
|
||||
expectedID string
|
||||
}{
|
||||
{
|
||||
name: "Parse base64-encoded protobuf composite ID",
|
||||
// This is a real Dex composite ID: contains user ID "cf5db180-d360-484d-9b78-c5db92146420" and connector "local"
|
||||
compositeID: "CiRjZjVkYjE4MC1kMzYwLTQ4NGQtOWI3OC1jNWRiOTIxNDY0MjASBWxvY2Fs",
|
||||
expectedID: "cf5db180-d360-484d-9b78-c5db92146420",
|
||||
},
|
||||
{
|
||||
name: "Return plain ID unchanged",
|
||||
compositeID: "simple-user-id",
|
||||
expectedID: "simple-user-id",
|
||||
},
|
||||
{
|
||||
name: "Return UUID unchanged",
|
||||
compositeID: "cf5db180-d360-484d-9b78-c5db92146420",
|
||||
expectedID: "cf5db180-d360-484d-9b78-c5db92146420",
|
||||
},
|
||||
{
|
||||
name: "Handle empty string",
|
||||
compositeID: "",
|
||||
expectedID: "",
|
||||
},
|
||||
{
|
||||
name: "Handle invalid base64",
|
||||
compositeID: "not-valid-base64!!!",
|
||||
expectedID: "not-valid-base64!!!",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseDexUserID(tt.compositeID)
|
||||
require.Equal(t, tt.expectedID, result, "parsed user ID should match expected")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -173,40 +173,40 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
||||
|
||||
return NewZitadelManager(*zitadelClientConfig, appMetrics)
|
||||
case "authentik":
|
||||
authentikConfig := AuthentikClientConfig{
|
||||
return NewAuthentikManager(AuthentikClientConfig{
|
||||
Issuer: config.ClientConfig.Issuer,
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
Username: config.ExtraConfig["Username"],
|
||||
Password: config.ExtraConfig["Password"],
|
||||
}
|
||||
return NewAuthentikManager(authentikConfig, appMetrics)
|
||||
}, appMetrics)
|
||||
case "okta":
|
||||
oktaClientConfig := OktaClientConfig{
|
||||
return NewOktaManager(OktaClientConfig{
|
||||
Issuer: config.ClientConfig.Issuer,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
}
|
||||
return NewOktaManager(oktaClientConfig, appMetrics)
|
||||
}, appMetrics)
|
||||
case "google":
|
||||
googleClientConfig := GoogleWorkspaceClientConfig{
|
||||
return NewGoogleWorkspaceManager(ctx, GoogleWorkspaceClientConfig{
|
||||
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
|
||||
CustomerID: config.ExtraConfig["CustomerId"],
|
||||
}
|
||||
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
|
||||
}, appMetrics)
|
||||
case "jumpcloud":
|
||||
jumpcloudConfig := JumpCloudClientConfig{
|
||||
return NewJumpCloudManager(JumpCloudClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
}
|
||||
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
|
||||
}, appMetrics)
|
||||
case "pocketid":
|
||||
pocketidConfig := PocketIdClientConfig{
|
||||
return NewPocketIdManager(PocketIdClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
|
||||
}
|
||||
return NewPocketIdManager(pocketidConfig, appMetrics)
|
||||
}, appMetrics)
|
||||
case "dex":
|
||||
return NewDexManager(DexClientConfig{
|
||||
GRPCAddr: config.ExtraConfig["GRPCAddr"],
|
||||
Issuer: config.ClientConfig.Issuer,
|
||||
}, appMetrics)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ package mock_server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
@@ -988,3 +989,7 @@ func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, ac
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) {
|
||||
return "something", nil
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
||||
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
|
||||
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers())
|
||||
for _, p := range aclPeers {
|
||||
peersMap[p.ID] = p
|
||||
}
|
||||
@@ -1057,7 +1057,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
|
||||
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers())
|
||||
for _, aclPeer := range aclPeers {
|
||||
if aclPeer.ID == peer.ID {
|
||||
return peer, nil
|
||||
|
||||
@@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
t.Run("check that all peers get map", func(t *testing.T) {
|
||||
for _, p := range account.Peers {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
|
||||
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check first peer map details", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 8)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
@@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check port ranges support for older peers", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 1)
|
||||
assert.Contains(t, peers, account.Peers["peerI"])
|
||||
|
||||
@@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("check first peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
account.Policies[1].Rules[0].Bidirectional = false
|
||||
|
||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||
// will establish a connection with all source peers satisfying the NB posture check.
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, 7)
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||
|
||||
@@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 3)
|
||||
assert.Len(t, firewallRules, 3)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 5)
|
||||
// assert peers from Group Swarm
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
@@ -63,6 +63,8 @@ type SqlStore struct {
|
||||
installationPK int
|
||||
storeEngine types.Engine
|
||||
pool *pgxpool.Pool
|
||||
|
||||
transactionTimeout time.Duration
|
||||
}
|
||||
|
||||
type installation struct {
|
||||
@@ -84,6 +86,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
conns = runtime.NumCPU()
|
||||
}
|
||||
|
||||
transactionTimeout := 5 * time.Minute
|
||||
if v := os.Getenv("NB_STORE_TRANSACTION_TIMEOUT"); v != "" {
|
||||
if parsed, err := time.ParseDuration(v); err == nil {
|
||||
transactionTimeout = parsed
|
||||
}
|
||||
}
|
||||
log.WithContext(ctx).Infof("Setting transaction timeout to %v", transactionTimeout)
|
||||
|
||||
if storeEngine == types.SqliteStoreEngine {
|
||||
if err == nil {
|
||||
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
|
||||
@@ -101,7 +111,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
|
||||
if skipMigration {
|
||||
log.WithContext(ctx).Infof("skipping migration")
|
||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1, transactionTimeout: transactionTimeout}, nil
|
||||
}
|
||||
|
||||
if err := migratePreAuto(ctx, db); err != nil {
|
||||
@@ -120,7 +130,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
return nil, fmt.Errorf("migratePostAuto: %w", err)
|
||||
}
|
||||
|
||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1, transactionTimeout: transactionTimeout}, nil
|
||||
}
|
||||
|
||||
func GetKeyQueryCondition(s *SqlStore) string {
|
||||
@@ -1910,16 +1920,17 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
|
||||
if len(policyIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
|
||||
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)`
|
||||
rows, err := s.pool.Query(ctx, query, policyIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
|
||||
var r types.PolicyRule
|
||||
var dest, destRes, sources, sourceRes, ports, portRanges []byte
|
||||
var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte
|
||||
var enabled, bidirectional sql.NullBool
|
||||
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges)
|
||||
var authorizedUser sql.NullString
|
||||
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &authorizedUser)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
r.Enabled = enabled.Bool
|
||||
@@ -1945,6 +1956,12 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
|
||||
if portRanges != nil {
|
||||
_ = json.Unmarshal(portRanges, &r.PortRanges)
|
||||
}
|
||||
if authorizedGroups != nil {
|
||||
_ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups)
|
||||
}
|
||||
if authorizedUser.Valid {
|
||||
r.AuthorizedUser = authorizedUser.String
|
||||
}
|
||||
}
|
||||
return &r, err
|
||||
})
|
||||
@@ -2890,8 +2907,11 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout)
|
||||
defer cancel()
|
||||
|
||||
startTime := time.Now()
|
||||
tx := s.db.Begin()
|
||||
tx := s.db.WithContext(timeoutCtx).Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
@@ -2926,6 +2946,9 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
|
||||
err := operation(repo)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
|
||||
log.WithContext(ctx).Warnf("transaction exceeded %s timeout after %v, stack: %s", s.transactionTimeout, time.Since(startTime), debug.Stack())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2938,13 +2961,19 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
|
||||
log.WithContext(ctx).Warnf("transaction commit exceeded %s timeout after %v, stack: %s", s.transactionTimeout, time.Since(startTime), debug.Stack())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
|
||||
}
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
@@ -4075,3 +4104,21 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var userID string
|
||||
result := tx.Model(&nbpeer.Peer{}).
|
||||
Select("user_id").
|
||||
Take(&userID, GetKeyQueryCondition(s), peerKey)
|
||||
|
||||
if result.Error != nil {
|
||||
return "", status.Errorf(status.Internal, "failed to get user ID by peer key")
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
@@ -3718,6 +3718,69 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserIDByPeerKey(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
userID := "test-user-123"
|
||||
peerKey := "peer-key-abc"
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
ID: "test-peer-1",
|
||||
Key: peerKey,
|
||||
AccountID: existingAccountID,
|
||||
UserID: userID,
|
||||
IP: net.IP{10, 0, 0, 1},
|
||||
DNSLabel: "test-peer-1",
|
||||
}
|
||||
|
||||
err = store.AddPeerToAccount(context.Background(), peer)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrievedUserID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, peerKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userID, retrievedUserID)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserIDByPeerKey_NotFound(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
nonExistentPeerKey := "non-existent-peer-key"
|
||||
|
||||
userID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, nonExistentPeerKey)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "", userID)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserIDByPeerKey_NoUserID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
peerKey := "peer-key-abc"
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
ID: "test-peer-1",
|
||||
Key: peerKey,
|
||||
AccountID: existingAccountID,
|
||||
UserID: "",
|
||||
IP: net.IP{10, 0, 0, 1},
|
||||
DNSLabel: "test-peer-1",
|
||||
}
|
||||
|
||||
err = store.AddPeerToAccount(context.Background(), peer)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrievedUserID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, peerKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", retrievedUserID)
|
||||
}
|
||||
|
||||
func TestSqlStore_ApproveAccountPeers(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
accountID := "test-account"
|
||||
@@ -3794,3 +3857,30 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) {
|
||||
if os.Getenv("NETBIRD_STORE_ENGINE") == "mysql" {
|
||||
t.Skip("Skipping timeout test for MySQL")
|
||||
}
|
||||
|
||||
t.Setenv("NB_STORE_TRANSACTION_TIMEOUT", "1s")
|
||||
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
sqlStore, ok := store.(*SqlStore)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 1*time.Second, sqlStore.transactionTimeout)
|
||||
|
||||
ctx := context.Background()
|
||||
err = sqlStore.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
// Sleep for 2 seconds to exceed the 1 second timeout
|
||||
time.Sleep(2 * time.Second)
|
||||
return nil
|
||||
})
|
||||
|
||||
// The transaction should fail with an error (either timeout or already rolled back)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
|
||||
}
|
||||
|
||||
@@ -204,6 +204,7 @@ type Store interface {
|
||||
MarkAccountPrimary(ctx context.Context, accountID string) error
|
||||
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
|
||||
GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error)
|
||||
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@@ -45,8 +46,10 @@ const (
|
||||
|
||||
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
|
||||
nativeSSHPortString = "22022"
|
||||
nativeSSHPortNumber = 22022
|
||||
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
|
||||
defaultSSHPortString = "22"
|
||||
defaultSSHPortNumber = 22
|
||||
)
|
||||
|
||||
type supportedFeatures struct {
|
||||
@@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
groupIDToUserIDs map[string][]string,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
peer := a.Peers[peerID]
|
||||
@@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
}
|
||||
}
|
||||
|
||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
||||
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
@@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap(
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
EnableSSH: enableSSH,
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
@@ -1009,8 +1015,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
||||
// GetPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||
authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs
|
||||
sshEnabled := false
|
||||
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
@@ -1053,10 +1061,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
|
||||
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
sshEnabled = true
|
||||
switch {
|
||||
case len(rule.AuthorizedGroups) > 0:
|
||||
for groupID, localUsers := range rule.AuthorizedGroups {
|
||||
userIDs, ok := groupIDToUserIDs[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(localUsers) == 0 {
|
||||
localUsers = []string{auth.Wildcard}
|
||||
}
|
||||
|
||||
for _, localUser := range localUsers {
|
||||
if authorizedUsers[localUser] == nil {
|
||||
authorizedUsers[localUser] = make(map[string]struct{})
|
||||
}
|
||||
for _, userID := range userIDs {
|
||||
authorizedUsers[localUser][userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
case rule.AuthorizedUser != "":
|
||||
if authorizedUsers[auth.Wildcard] == nil {
|
||||
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
|
||||
}
|
||||
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
|
||||
default:
|
||||
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||
}
|
||||
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
|
||||
sshEnabled = true
|
||||
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return getAccumulatedResources()
|
||||
peers, fwRules := getAccumulatedResources()
|
||||
return peers, fwRules, authorizedUsers, sshEnabled
|
||||
}
|
||||
|
||||
func (a *Account) getAllowedUserIDs() map[string]struct{} {
|
||||
users := make(map[string]struct{})
|
||||
for _, nbUser := range a.Users {
|
||||
if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
|
||||
users[nbUser.Id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||
@@ -1081,12 +1137,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
peersExists[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
protocol := rule.Protocol
|
||||
if protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
protocol = PolicyRuleProtocolTCP
|
||||
}
|
||||
|
||||
fr := FirewallRule{
|
||||
PolicyID: rule.ID,
|
||||
PeerIP: peer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
Protocol: string(protocol),
|
||||
}
|
||||
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
@@ -1108,6 +1169,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
}
|
||||
}
|
||||
|
||||
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
|
||||
}
|
||||
|
||||
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
|
||||
for _, pr := range portRanges {
|
||||
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func portsIncludesSSH(ports []string) bool {
|
||||
for _, port := range ports {
|
||||
if port == defaultSSHPortString || port == nativeSSHPortString {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
//
|
||||
// Returns a list of peers from specified groups that pass specified posture checks
|
||||
@@ -1152,7 +1235,11 @@ func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpe
|
||||
return []*nbpeer.Peer{}, false
|
||||
}
|
||||
|
||||
return []*nbpeer.Peer{peer}, resource.ID == peerID
|
||||
if peer.ID == peerID {
|
||||
return []*nbpeer.Peer{}, true
|
||||
}
|
||||
|
||||
return []*nbpeer.Peer{peer}, false
|
||||
}
|
||||
|
||||
// validatePostureChecksOnPeer validates the posture checks on a peer
|
||||
@@ -1660,6 +1747,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) GetActiveGroupUsers() map[string][]string {
|
||||
allGroupID := ""
|
||||
group, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get group all: %v", err)
|
||||
} else {
|
||||
allGroupID = group.ID
|
||||
}
|
||||
groups := make(map[string][]string, len(a.GroupsG))
|
||||
for _, user := range a.Users {
|
||||
if !user.IsBlocked() && !user.IsServiceUser {
|
||||
for _, groupID := range user.AutoGroups {
|
||||
groups[groupID] = append(groups[groupID], user.Id)
|
||||
}
|
||||
groups[allGroupID] = append(groups[allGroupID], user.Id)
|
||||
}
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
||||
@@ -1691,7 +1798,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) {
|
||||
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
expanded = addNativeSSHRule(base, expanded)
|
||||
}
|
||||
|
||||
|
||||
@@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetActiveGroupUsers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected map[string][]string
|
||||
}{
|
||||
{
|
||||
name: "all users are active",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group2", "group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1", "user3"},
|
||||
"group2": {"user1", "user2"},
|
||||
"group3": {"user2"},
|
||||
"": {"user1", "user2", "user3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "some users are blocked",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group2", "group3"},
|
||||
Blocked: true,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group1", "group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1", "user3"},
|
||||
"group2": {"user1"},
|
||||
"group3": {"user3"},
|
||||
"": {"user1", "user3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all users are blocked",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: true,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group2"},
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{},
|
||||
},
|
||||
{
|
||||
name: "user with no auto groups",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user2"},
|
||||
"": {"user1", "user2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty account",
|
||||
account: &Account{
|
||||
Users: map[string]*User{},
|
||||
},
|
||||
expected: map[string][]string{},
|
||||
},
|
||||
{
|
||||
name: "multiple users in same group",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1", "user2", "user3"},
|
||||
"": {"user1", "user2", "user3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user in multiple groups with blocked users",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1", "group2", "group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
Blocked: true,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1"},
|
||||
"group2": {"user1"},
|
||||
"group3": {"user1", "user3"},
|
||||
"": {"user1", "user3"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.account.GetActiveGroupUsers()
|
||||
|
||||
// Check that the number of groups matches
|
||||
assert.Equal(t, len(tt.expected), len(result), "number of groups should match")
|
||||
|
||||
// Check each group's users
|
||||
for groupID, expectedUsers := range tt.expected {
|
||||
actualUsers, exists := result[groupID]
|
||||
assert.True(t, exists, "group %s should exist in result", groupID)
|
||||
assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID)
|
||||
}
|
||||
|
||||
// Ensure no extra groups in result
|
||||
for groupID := range result {
|
||||
_, exists := tt.expected[groupID]
|
||||
assert.True(t, exists, "unexpected group %s in result", groupID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -38,6 +38,8 @@ type NetworkMap struct {
|
||||
FirewallRules []*FirewallRule
|
||||
RoutesFirewallRules []*RouteFirewallRule
|
||||
ForwardingRules []*ForwardingRule
|
||||
AuthorizedUsers map[string]map[string]struct{}
|
||||
EnableSSH bool
|
||||
}
|
||||
|
||||
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
||||
|
||||
@@ -69,7 +69,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
@@ -141,7 +141,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
||||
b.Run("old builder", func(b *testing.B) {
|
||||
for range b.N {
|
||||
for _, peerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -201,7 +201,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
@@ -320,7 +320,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -395,7 +395,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
@@ -550,7 +550,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -604,7 +604,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
@@ -730,7 +730,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
@@ -847,7 +847,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
b.Run("old builder after delete", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -23,6 +23,8 @@ const (
|
||||
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
|
||||
// PolicyRuleProtocolICMP type of traffic
|
||||
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
|
||||
// PolicyRuleProtocolNetbirdSSH type of traffic
|
||||
PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error)
|
||||
protocol = PolicyRuleProtocolUDP
|
||||
case "icmp":
|
||||
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
|
||||
case "netbird-ssh":
|
||||
return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil
|
||||
default:
|
||||
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
|
||||
}
|
||||
|
||||
@@ -80,6 +80,12 @@ type PolicyRule struct {
|
||||
|
||||
// PortRanges a list of port ranges.
|
||||
PortRanges []RulePortRange `gorm:"serializer:json"`
|
||||
|
||||
// AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh
|
||||
AuthorizedGroups map[string][]string `gorm:"serializer:json"`
|
||||
|
||||
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
|
||||
AuthorizedUser string
|
||||
}
|
||||
|
||||
// Copy returns a copy of a policy rule
|
||||
@@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
Protocol: pm.Protocol,
|
||||
Ports: make([]string, len(pm.Ports)),
|
||||
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
||||
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
|
||||
AuthorizedUser: pm.AuthorizedUser,
|
||||
}
|
||||
copy(rule.Destinations, pm.Destinations)
|
||||
copy(rule.Sources, pm.Sources)
|
||||
copy(rule.Ports, pm.Ports)
|
||||
copy(rule.PortRanges, pm.PortRanges)
|
||||
for k, v := range pm.AuthorizedGroups {
|
||||
rule.AuthorizedGroups[k] = make([]string, len(v))
|
||||
copy(rule.AuthorizedGroups[k], v)
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
@@ -523,16 +523,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
||||
_, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
||||
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
|
||||
}
|
||||
|
||||
if userHadPeers {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
updateAccountPeers = true
|
||||
|
||||
err = transaction.SaveUser(ctx, updatedUser)
|
||||
if err != nil {
|
||||
@@ -581,7 +579,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
}
|
||||
|
||||
if settings.GroupsPropagationEnabled && updateAccountPeers {
|
||||
if updateAccountPeers {
|
||||
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
@@ -1379,11 +1379,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Creating a new regular user should not update account peers and not send peer update
|
||||
// Creating a new regular user should send peer update (as users are not filtered yet)
|
||||
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1402,11 +1402,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// updating user with no linked peers should not update account peers and not send peer update
|
||||
// updating user with no linked peers should update account peers and send peer update (as users are not filtered yet)
|
||||
t.Run("updating user with no linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
216
release_files/freebsd-port-diff.sh
Executable file
216
release_files/freebsd-port-diff.sh
Executable file
@@ -0,0 +1,216 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# FreeBSD Port Diff Generator for NetBird
|
||||
#
|
||||
# This script generates the diff file required for submitting a FreeBSD port update.
|
||||
# It works on macOS, Linux, and FreeBSD by fetching files from FreeBSD cgit and
|
||||
# computing checksums from the Go module proxy.
|
||||
#
|
||||
# Usage: ./freebsd-port-diff.sh [new_version]
|
||||
# Example: ./freebsd-port-diff.sh 0.60.7
|
||||
#
|
||||
# If no version is provided, it fetches the latest from GitHub.
|
||||
|
||||
set -e
|
||||
|
||||
GITHUB_REPO="netbirdio/netbird"
|
||||
PORTS_CGIT_BASE="https://cgit.freebsd.org/ports/plain/security/netbird"
|
||||
GO_PROXY="https://proxy.golang.org/github.com/netbirdio/netbird/@v"
|
||||
OUTPUT_DIR="${OUTPUT_DIR:-.}"
|
||||
AWK_FIRST_FIELD='{print $1}'
|
||||
|
||||
fetch_all_tags() {
|
||||
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
|
||||
sed 's/.*\/v//' | \
|
||||
sort -u -V
|
||||
return 0
|
||||
}
|
||||
|
||||
fetch_current_ports_version() {
|
||||
echo "Fetching current version from FreeBSD ports..." >&2
|
||||
curl -sL "${PORTS_CGIT_BASE}/Makefile" 2>/dev/null | \
|
||||
grep -E "^DISTVERSION=" | \
|
||||
sed 's/DISTVERSION=[[:space:]]*//' | \
|
||||
tr -d '\t '
|
||||
return 0
|
||||
}
|
||||
|
||||
fetch_latest_github_release() {
|
||||
echo "Fetching latest release from GitHub..." >&2
|
||||
fetch_all_tags | tail -1
|
||||
return 0
|
||||
}
|
||||
|
||||
fetch_ports_file() {
|
||||
local filename="$1"
|
||||
curl -sL "${PORTS_CGIT_BASE}/${filename}" 2>/dev/null
|
||||
return 0
|
||||
}
|
||||
|
||||
compute_checksums() {
|
||||
local version="$1"
|
||||
local tmpdir
|
||||
tmpdir=$(mktemp -d)
|
||||
# shellcheck disable=SC2064
|
||||
trap "rm -rf '$tmpdir'" EXIT
|
||||
|
||||
echo "Downloading files from Go module proxy for v${version}..." >&2
|
||||
|
||||
local mod_file="${tmpdir}/v${version}.mod"
|
||||
local zip_file="${tmpdir}/v${version}.zip"
|
||||
|
||||
curl -sL "${GO_PROXY}/v${version}.mod" -o "$mod_file" 2>/dev/null
|
||||
curl -sL "${GO_PROXY}/v${version}.zip" -o "$zip_file" 2>/dev/null
|
||||
|
||||
if [[ ! -s "$mod_file" ]] || [[ ! -s "$zip_file" ]]; then
|
||||
echo "Error: Could not download files from Go module proxy" >&2
|
||||
return 1
|
||||
fi
|
||||
|
||||
local mod_sha256 mod_size zip_sha256 zip_size
|
||||
|
||||
if command -v sha256sum &>/dev/null; then
|
||||
mod_sha256=$(sha256sum "$mod_file" | awk "$AWK_FIRST_FIELD")
|
||||
zip_sha256=$(sha256sum "$zip_file" | awk "$AWK_FIRST_FIELD")
|
||||
elif command -v shasum &>/dev/null; then
|
||||
mod_sha256=$(shasum -a 256 "$mod_file" | awk "$AWK_FIRST_FIELD")
|
||||
zip_sha256=$(shasum -a 256 "$zip_file" | awk "$AWK_FIRST_FIELD")
|
||||
else
|
||||
echo "Error: No sha256 command found" >&2
|
||||
return 1
|
||||
fi
|
||||
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
mod_size=$(stat -f%z "$mod_file")
|
||||
zip_size=$(stat -f%z "$zip_file")
|
||||
else
|
||||
mod_size=$(stat -c%s "$mod_file")
|
||||
zip_size=$(stat -c%s "$zip_file")
|
||||
fi
|
||||
|
||||
echo "TIMESTAMP = $(date +%s)"
|
||||
echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_sha256}"
|
||||
echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_size}"
|
||||
echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_sha256}"
|
||||
echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_size}"
|
||||
return 0
|
||||
}
|
||||
|
||||
generate_new_makefile() {
|
||||
local new_version="$1"
|
||||
local old_makefile="$2"
|
||||
|
||||
# Check if old version had PORTREVISION
|
||||
if echo "$old_makefile" | grep -q "^PORTREVISION="; then
|
||||
# Remove PORTREVISION line and update DISTVERSION
|
||||
echo "$old_makefile" | \
|
||||
sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/" | \
|
||||
grep -v "^PORTREVISION="
|
||||
else
|
||||
# Just update DISTVERSION
|
||||
echo "$old_makefile" | \
|
||||
sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/"
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
NEW_VERSION="${1:-}"
|
||||
|
||||
# Auto-detect versions if not provided
|
||||
OLD_VERSION=$(fetch_current_ports_version)
|
||||
if [[ -z "$OLD_VERSION" ]]; then
|
||||
echo "Error: Could not fetch current version from FreeBSD ports" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "Current FreeBSD ports version: ${OLD_VERSION}" >&2
|
||||
|
||||
if [[ -z "$NEW_VERSION" ]]; then
|
||||
NEW_VERSION=$(fetch_latest_github_release)
|
||||
if [[ -z "$NEW_VERSION" ]]; then
|
||||
echo "Error: Could not fetch latest release from GitHub" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
echo "Target version: ${NEW_VERSION}" >&2
|
||||
|
||||
if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then
|
||||
echo "Port is already at version ${NEW_VERSION}. Nothing to do." >&2
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "" >&2
|
||||
|
||||
# Fetch current files
|
||||
echo "Fetching current Makefile from FreeBSD ports..." >&2
|
||||
OLD_MAKEFILE=$(fetch_ports_file "Makefile")
|
||||
if [[ -z "$OLD_MAKEFILE" ]]; then
|
||||
echo "Error: Could not fetch Makefile" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Fetching current distinfo from FreeBSD ports..." >&2
|
||||
OLD_DISTINFO=$(fetch_ports_file "distinfo")
|
||||
if [[ -z "$OLD_DISTINFO" ]]; then
|
||||
echo "Error: Could not fetch distinfo" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Generate new files
|
||||
echo "Generating new Makefile..." >&2
|
||||
NEW_MAKEFILE=$(generate_new_makefile "$NEW_VERSION" "$OLD_MAKEFILE")
|
||||
|
||||
echo "Computing checksums for new version..." >&2
|
||||
NEW_DISTINFO=$(compute_checksums "$NEW_VERSION")
|
||||
if [[ -z "$NEW_DISTINFO" ]]; then
|
||||
echo "Error: Could not compute checksums" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Create temp files for diff
|
||||
TMPDIR=$(mktemp -d)
|
||||
# shellcheck disable=SC2064
|
||||
trap "rm -rf '$TMPDIR'" EXIT
|
||||
|
||||
mkdir -p "${TMPDIR}/a/security/netbird" "${TMPDIR}/b/security/netbird"
|
||||
|
||||
echo "$OLD_MAKEFILE" > "${TMPDIR}/a/security/netbird/Makefile"
|
||||
echo "$OLD_DISTINFO" > "${TMPDIR}/a/security/netbird/distinfo"
|
||||
echo "$NEW_MAKEFILE" > "${TMPDIR}/b/security/netbird/Makefile"
|
||||
echo "$NEW_DISTINFO" > "${TMPDIR}/b/security/netbird/distinfo"
|
||||
|
||||
# Generate diff
|
||||
OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}.diff"
|
||||
|
||||
echo "" >&2
|
||||
echo "Generating diff..." >&2
|
||||
|
||||
# Generate diff and clean up temp paths to show standard a/b paths
|
||||
(cd "${TMPDIR}" && diff -ruN "a/security/netbird" "b/security/netbird") > "$OUTPUT_FILE" || true
|
||||
|
||||
if [[ ! -s "$OUTPUT_FILE" ]]; then
|
||||
echo "Error: Generated diff is empty" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "" >&2
|
||||
echo "========================================="
|
||||
echo "Diff saved to: ${OUTPUT_FILE}"
|
||||
echo "========================================="
|
||||
echo ""
|
||||
cat "$OUTPUT_FILE"
|
||||
echo ""
|
||||
echo "========================================="
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo "1. Review the diff above"
|
||||
echo "2. Submit to https://bugs.freebsd.org/bugzilla/"
|
||||
echo "3. Use ./freebsd-port-issue-body.sh to generate the issue content"
|
||||
echo ""
|
||||
echo "For FreeBSD testing (optional but recommended):"
|
||||
echo " cd /usr/ports/security/netbird"
|
||||
echo " patch < ${OUTPUT_FILE}"
|
||||
echo " make stage && make stage-qa && make package && make install"
|
||||
echo " netbird status"
|
||||
echo " make deinstall"
|
||||
159
release_files/freebsd-port-issue-body.sh
Executable file
159
release_files/freebsd-port-issue-body.sh
Executable file
@@ -0,0 +1,159 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# FreeBSD Port Issue Body Generator for NetBird
|
||||
#
|
||||
# This script generates the issue body content for submitting a FreeBSD port update
|
||||
# to the FreeBSD Bugzilla at https://bugs.freebsd.org/bugzilla/
|
||||
#
|
||||
# Usage: ./freebsd-port-issue-body.sh [old_version] [new_version]
|
||||
# Example: ./freebsd-port-issue-body.sh 0.56.0 0.59.1
|
||||
#
|
||||
# If no versions are provided, the script will:
|
||||
# - Fetch OLD version from FreeBSD ports cgit (current version in ports tree)
|
||||
# - Fetch NEW version from latest NetBird GitHub release tag
|
||||
|
||||
set -e
|
||||
|
||||
GITHUB_REPO="netbirdio/netbird"
|
||||
PORTS_CGIT_URL="https://cgit.freebsd.org/ports/plain/security/netbird/Makefile"
|
||||
|
||||
fetch_current_ports_version() {
|
||||
echo "Fetching current version from FreeBSD ports..." >&2
|
||||
local makefile_content
|
||||
makefile_content=$(curl -sL "$PORTS_CGIT_URL" 2>/dev/null)
|
||||
if [[ -z "$makefile_content" ]]; then
|
||||
echo "Error: Could not fetch Makefile from FreeBSD ports" >&2
|
||||
return 1
|
||||
fi
|
||||
echo "$makefile_content" | grep -E "^DISTVERSION=" | sed 's/DISTVERSION=[[:space:]]*//' | tr -d '\t '
|
||||
return 0
|
||||
}
|
||||
|
||||
fetch_all_tags() {
|
||||
# Fetch tags from GitHub tags page (no rate limiting, no auth needed)
|
||||
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
|
||||
sed 's/.*\/v//' | \
|
||||
sort -u -V
|
||||
return 0
|
||||
}
|
||||
|
||||
fetch_latest_github_release() {
|
||||
echo "Fetching latest release from GitHub..." >&2
|
||||
local latest
|
||||
|
||||
# Fetch from GitHub tags page
|
||||
latest=$(fetch_all_tags | tail -1)
|
||||
|
||||
if [[ -z "$latest" ]]; then
|
||||
# Fallback to GitHub API
|
||||
latest=$(curl -sL "https://api.github.com/repos/${GITHUB_REPO}/releases/latest" 2>/dev/null | \
|
||||
grep '"tag_name"' | sed 's/.*"tag_name": *"v\([^"]*\)".*/\1/')
|
||||
fi
|
||||
|
||||
if [[ -z "$latest" ]]; then
|
||||
echo "Error: Could not fetch latest release from GitHub" >&2
|
||||
return 1
|
||||
fi
|
||||
echo "$latest"
|
||||
return 0
|
||||
}
|
||||
|
||||
OLD_VERSION="${1:-}"
|
||||
NEW_VERSION="${2:-}"
|
||||
|
||||
if [[ -z "$OLD_VERSION" ]]; then
|
||||
OLD_VERSION=$(fetch_current_ports_version)
|
||||
if [[ -z "$OLD_VERSION" ]]; then
|
||||
echo "Error: Could not determine old version. Please provide it manually." >&2
|
||||
echo "Usage: $0 <old_version> <new_version>" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "Detected OLD version from FreeBSD ports: $OLD_VERSION" >&2
|
||||
fi
|
||||
|
||||
if [[ -z "$NEW_VERSION" ]]; then
|
||||
NEW_VERSION=$(fetch_latest_github_release)
|
||||
if [[ -z "$NEW_VERSION" ]]; then
|
||||
echo "Error: Could not determine new version. Please provide it manually." >&2
|
||||
echo "Usage: $0 <old_version> <new_version>" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "Detected NEW version from GitHub: $NEW_VERSION" >&2
|
||||
fi
|
||||
|
||||
if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then
|
||||
echo "Warning: OLD and NEW versions are the same ($OLD_VERSION). Port may already be up to date." >&2
|
||||
fi
|
||||
|
||||
echo "" >&2
|
||||
|
||||
OUTPUT_DIR="${OUTPUT_DIR:-.}"
|
||||
|
||||
fetch_releases_between_versions() {
|
||||
echo "Fetching release history from GitHub..." >&2
|
||||
|
||||
# Fetch all tags and filter to those between OLD and NEW versions
|
||||
fetch_all_tags | \
|
||||
while read -r ver; do
|
||||
if [[ "$(printf '%s\n' "$OLD_VERSION" "$ver" | sort -V | head -n1)" = "$OLD_VERSION" ]] && \
|
||||
[[ "$(printf '%s\n' "$ver" "$NEW_VERSION" | sort -V | head -n1)" = "$ver" ]] && \
|
||||
[[ "$ver" != "$OLD_VERSION" ]]; then
|
||||
echo "$ver"
|
||||
fi
|
||||
done
|
||||
return 0
|
||||
}
|
||||
|
||||
generate_changelog_section() {
|
||||
local releases
|
||||
releases=$(fetch_releases_between_versions)
|
||||
|
||||
echo "Changelogs:"
|
||||
if [[ -n "$releases" ]]; then
|
||||
echo "$releases" | while read -r ver; do
|
||||
echo "https://github.com/${GITHUB_REPO}/releases/tag/v${ver}"
|
||||
done
|
||||
else
|
||||
echo "https://github.com/${GITHUB_REPO}/releases/tag/v${NEW_VERSION}"
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}-issue.txt"
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
BUGZILLA ISSUE DETAILS
|
||||
======================
|
||||
|
||||
Severity: Affects Some People
|
||||
|
||||
Summary: security/netbird: Update to ${NEW_VERSION}
|
||||
|
||||
Description:
|
||||
------------
|
||||
security/netbird: Update ${OLD_VERSION} => ${NEW_VERSION}
|
||||
|
||||
$(generate_changelog_section)
|
||||
|
||||
Commit log:
|
||||
https://github.com/${GITHUB_REPO}/compare/v${OLD_VERSION}...v${NEW_VERSION}
|
||||
EOF
|
||||
|
||||
echo "========================================="
|
||||
echo "Issue body saved to: ${OUTPUT_FILE}"
|
||||
echo "========================================="
|
||||
echo ""
|
||||
cat "$OUTPUT_FILE"
|
||||
echo ""
|
||||
echo "========================================="
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo "1. Go to https://bugs.freebsd.org/bugzilla/ and login"
|
||||
echo "2. Click 'Report an update or defect to a port'"
|
||||
echo "3. Fill in:"
|
||||
echo " - Severity: Affects Some People"
|
||||
echo " - Summary: security/netbird: Update to ${NEW_VERSION}"
|
||||
echo " - Description: Copy content from ${OUTPUT_FILE}"
|
||||
echo "4. Attach diff file: netbird-${NEW_VERSION}.diff"
|
||||
echo "5. Submit the bug report"
|
||||
@@ -111,6 +111,8 @@ func (c *GrpcClient) ready() bool {
|
||||
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
||||
// Blocking request. The result will be sent via msgHandler callback function
|
||||
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
backOff := defaultBackoff(ctx)
|
||||
|
||||
operation := func() error {
|
||||
log.Debugf("management connection state %v", c.conn.GetState())
|
||||
connState := c.conn.GetState()
|
||||
@@ -128,10 +130,10 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
|
||||
return err
|
||||
}
|
||||
|
||||
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
|
||||
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler, backOff)
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, defaultBackoff(ctx))
|
||||
err := backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
|
||||
}
|
||||
@@ -140,7 +142,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
|
||||
}
|
||||
|
||||
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
|
||||
msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
msgHandler func(msg *proto.SyncResponse) error, backOff backoff.BackOff) error {
|
||||
ctx, cancelStream := context.WithCancel(ctx)
|
||||
defer cancelStream()
|
||||
|
||||
@@ -158,6 +160,9 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key,
|
||||
|
||||
// blocking until error
|
||||
err = c.receiveEvents(stream, serverPubKey, msgHandler)
|
||||
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
|
||||
// reset times and next try will start with a long delay
|
||||
backOff.Reset()
|
||||
if err != nil {
|
||||
c.notifyDisconnected(err)
|
||||
s, _ := gstatus.FromError(err)
|
||||
|
||||
@@ -488,6 +488,8 @@ components:
|
||||
description: Indicates whether the peer is ephemeral or not
|
||||
type: boolean
|
||||
example: false
|
||||
local_flags:
|
||||
$ref: '#/components/schemas/PeerLocalFlags'
|
||||
required:
|
||||
- city_name
|
||||
- connected
|
||||
@@ -514,6 +516,49 @@ components:
|
||||
- serial_number
|
||||
- extra_dns_labels
|
||||
- ephemeral
|
||||
PeerLocalFlags:
|
||||
type: object
|
||||
properties:
|
||||
rosenpass_enabled:
|
||||
description: Indicates whether Rosenpass is enabled on this peer
|
||||
type: boolean
|
||||
example: true
|
||||
rosenpass_permissive:
|
||||
description: Indicates whether Rosenpass is in permissive mode or not
|
||||
type: boolean
|
||||
example: false
|
||||
server_ssh_allowed:
|
||||
description: Indicates whether SSH access this peer is allowed or not
|
||||
type: boolean
|
||||
example: true
|
||||
disable_client_routes:
|
||||
description: Indicates whether client routes are disabled on this peer or not
|
||||
type: boolean
|
||||
example: false
|
||||
disable_server_routes:
|
||||
description: Indicates whether server routes are disabled on this peer or not
|
||||
type: boolean
|
||||
example: false
|
||||
disable_dns:
|
||||
description: Indicates whether DNS management is disabled on this peer or not
|
||||
type: boolean
|
||||
example: false
|
||||
disable_firewall:
|
||||
description: Indicates whether firewall management is disabled on this peer or not
|
||||
type: boolean
|
||||
example: false
|
||||
block_lan_access:
|
||||
description: Indicates whether LAN access is blocked on this peer when used as a routing peer
|
||||
type: boolean
|
||||
example: false
|
||||
block_inbound:
|
||||
description: Indicates whether inbound traffic is blocked on this peer
|
||||
type: boolean
|
||||
example: false
|
||||
lazy_connection_enabled:
|
||||
description: Indicates whether lazy connection is enabled on this peer
|
||||
type: boolean
|
||||
example: false
|
||||
PeerTemporaryAccessRequest:
|
||||
type: object
|
||||
properties:
|
||||
@@ -936,7 +981,7 @@ components:
|
||||
protocol:
|
||||
description: Policy rule type of the traffic
|
||||
type: string
|
||||
enum: ["all", "tcp", "udp", "icmp"]
|
||||
enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"]
|
||||
example: "tcp"
|
||||
ports:
|
||||
description: Policy rule affected ports
|
||||
@@ -949,6 +994,14 @@ components:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/RulePortRange'
|
||||
authorized_groups:
|
||||
description: Map of user group ids to a list of local users
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: "group1"
|
||||
required:
|
||||
- name
|
||||
- enabled
|
||||
|
||||
@@ -130,10 +130,11 @@ const (
|
||||
|
||||
// Defines values for PolicyRuleProtocol.
|
||||
const (
|
||||
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
|
||||
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
|
||||
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
|
||||
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
|
||||
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
|
||||
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
|
||||
PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh"
|
||||
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
|
||||
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
|
||||
)
|
||||
|
||||
// Defines values for PolicyRuleMinimumAction.
|
||||
@@ -144,10 +145,11 @@ const (
|
||||
|
||||
// Defines values for PolicyRuleMinimumProtocol.
|
||||
const (
|
||||
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
|
||||
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
|
||||
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
|
||||
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
|
||||
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
|
||||
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
|
||||
PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh"
|
||||
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
|
||||
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
|
||||
)
|
||||
|
||||
// Defines values for PolicyRuleUpdateAction.
|
||||
@@ -158,10 +160,11 @@ const (
|
||||
|
||||
// Defines values for PolicyRuleUpdateProtocol.
|
||||
const (
|
||||
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
|
||||
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
|
||||
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
|
||||
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
|
||||
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
|
||||
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
|
||||
PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh"
|
||||
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
|
||||
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
|
||||
)
|
||||
|
||||
// Defines values for ResourceType.
|
||||
@@ -1077,7 +1080,8 @@ type Peer struct {
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
|
||||
// LastSeen Last time peer connected to Netbird's management service
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
|
||||
|
||||
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||
@@ -1167,7 +1171,8 @@ type PeerBatch struct {
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
|
||||
// LastSeen Last time peer connected to Netbird's management service
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
|
||||
|
||||
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||
@@ -1197,6 +1202,39 @@ type PeerBatch struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// PeerLocalFlags defines model for PeerLocalFlags.
|
||||
type PeerLocalFlags struct {
|
||||
// BlockInbound Indicates whether inbound traffic is blocked on this peer
|
||||
BlockInbound *bool `json:"block_inbound,omitempty"`
|
||||
|
||||
// BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer
|
||||
BlockLanAccess *bool `json:"block_lan_access,omitempty"`
|
||||
|
||||
// DisableClientRoutes Indicates whether client routes are disabled on this peer or not
|
||||
DisableClientRoutes *bool `json:"disable_client_routes,omitempty"`
|
||||
|
||||
// DisableDns Indicates whether DNS management is disabled on this peer or not
|
||||
DisableDns *bool `json:"disable_dns,omitempty"`
|
||||
|
||||
// DisableFirewall Indicates whether firewall management is disabled on this peer or not
|
||||
DisableFirewall *bool `json:"disable_firewall,omitempty"`
|
||||
|
||||
// DisableServerRoutes Indicates whether server routes are disabled on this peer or not
|
||||
DisableServerRoutes *bool `json:"disable_server_routes,omitempty"`
|
||||
|
||||
// LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer
|
||||
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
|
||||
|
||||
// RosenpassEnabled Indicates whether Rosenpass is enabled on this peer
|
||||
RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"`
|
||||
|
||||
// RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not
|
||||
RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"`
|
||||
|
||||
// ServerSshAllowed Indicates whether SSH access this peer is allowed or not
|
||||
ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"`
|
||||
}
|
||||
|
||||
// PeerMinimum defines model for PeerMinimum.
|
||||
type PeerMinimum struct {
|
||||
// Id Peer ID
|
||||
@@ -1349,6 +1387,9 @@ type PolicyRule struct {
|
||||
// Action Policy rule accept or drops packets
|
||||
Action PolicyRuleAction `json:"action"`
|
||||
|
||||
// AuthorizedGroups Map of user group ids to a list of local users
|
||||
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||
|
||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||
Bidirectional bool `json:"bidirectional"`
|
||||
|
||||
@@ -1393,6 +1434,9 @@ type PolicyRuleMinimum struct {
|
||||
// Action Policy rule accept or drops packets
|
||||
Action PolicyRuleMinimumAction `json:"action"`
|
||||
|
||||
// AuthorizedGroups Map of user group ids to a list of local users
|
||||
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||
|
||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||
Bidirectional bool `json:"bidirectional"`
|
||||
|
||||
@@ -1426,6 +1470,9 @@ type PolicyRuleUpdate struct {
|
||||
// Action Policy rule accept or drops packets
|
||||
Action PolicyRuleUpdateAction `json:"action"`
|
||||
|
||||
// AuthorizedGroups Map of user group ids to a list of local users
|
||||
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||
|
||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||
Bidirectional bool `json:"bidirectional"`
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -332,6 +332,24 @@ message NetworkMap {
|
||||
bool routesFirewallRulesIsEmpty = 11;
|
||||
|
||||
repeated ForwardingRule forwardingRules = 12;
|
||||
|
||||
// SSHAuth represents SSH authorization configuration
|
||||
SSHAuth sshAuth = 13;
|
||||
}
|
||||
|
||||
message SSHAuth {
|
||||
// UserIDClaim is the JWT claim to be used to get the users ID
|
||||
string UserIDClaim = 1;
|
||||
|
||||
// AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH
|
||||
repeated bytes AuthorizedUsers = 2;
|
||||
|
||||
// MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list
|
||||
map<string, MachineUserIndexes> machine_users = 3;
|
||||
}
|
||||
|
||||
message MachineUserIndexes {
|
||||
repeated uint32 indexes = 1;
|
||||
}
|
||||
|
||||
// RemotePeerConfig represents a configuration of a remote peer.
|
||||
|
||||
28
shared/sshauth/userhash.go
Normal file
28
shared/sshauth/userhash.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package sshauth
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
|
||||
"golang.org/x/crypto/blake2b"
|
||||
)
|
||||
|
||||
// UserIDHash represents a hashed user ID (BLAKE2b-128)
|
||||
type UserIDHash [16]byte
|
||||
|
||||
// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value
|
||||
// This function must produce the same hash on both client and management server
|
||||
func HashUserID(userID string) (UserIDHash, error) {
|
||||
hash, err := blake2b.New(16, nil)
|
||||
if err != nil {
|
||||
return UserIDHash{}, err
|
||||
}
|
||||
hash.Write([]byte(userID))
|
||||
var result UserIDHash
|
||||
copy(result[:], hash.Sum(nil))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// String returns the hexadecimal string representation of the hash
|
||||
func (h UserIDHash) String() string {
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
210
shared/sshauth/userhash_test.go
Normal file
210
shared/sshauth/userhash_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package sshauth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHashUserID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
}{
|
||||
{
|
||||
name: "simple user ID",
|
||||
userID: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "UUID format",
|
||||
userID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
},
|
||||
{
|
||||
name: "numeric ID",
|
||||
userID: "12345",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
userID: "",
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
userID: "user+test@domain.com",
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
userID: "用户@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash, err := HashUserID(tt.userID)
|
||||
if err != nil {
|
||||
t.Errorf("HashUserID() error = %v, want nil", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify hash is non-zero for non-empty inputs
|
||||
if tt.userID != "" && hash == [16]byte{} {
|
||||
t.Errorf("HashUserID() returned zero hash for non-empty input")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashUserID_Consistency(t *testing.T) {
|
||||
userID := "test@example.com"
|
||||
|
||||
hash1, err1 := HashUserID(userID)
|
||||
if err1 != nil {
|
||||
t.Fatalf("First HashUserID() error = %v", err1)
|
||||
}
|
||||
|
||||
hash2, err2 := HashUserID(userID)
|
||||
if err2 != nil {
|
||||
t.Fatalf("Second HashUserID() error = %v", err2)
|
||||
}
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashUserID_Uniqueness(t *testing.T) {
|
||||
tests := []struct {
|
||||
userID1 string
|
||||
userID2 string
|
||||
}{
|
||||
{"user1@example.com", "user2@example.com"},
|
||||
{"alice@domain.com", "bob@domain.com"},
|
||||
{"test", "test1"},
|
||||
{"", "a"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
hash1, err1 := HashUserID(tt.userID1)
|
||||
if err1 != nil {
|
||||
t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1)
|
||||
}
|
||||
|
||||
hash2, err2 := HashUserID(tt.userID2)
|
||||
if err2 != nil {
|
||||
t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2)
|
||||
}
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserIDHash_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hash UserIDHash
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "zero hash",
|
||||
hash: [16]byte{},
|
||||
expected: "00000000000000000000000000000000",
|
||||
},
|
||||
{
|
||||
name: "small value",
|
||||
hash: [16]byte{15: 0xff},
|
||||
expected: "000000000000000000000000000000ff",
|
||||
},
|
||||
{
|
||||
name: "large value",
|
||||
hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe},
|
||||
expected: "0000000000000000deadbeefcafebabe",
|
||||
},
|
||||
{
|
||||
name: "max value",
|
||||
hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
expected: "ffffffffffffffffffffffffffffffff",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.hash.String()
|
||||
if result != tt.expected {
|
||||
t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserIDHash_String_Length(t *testing.T) {
|
||||
// Test that String() always returns 32 hex characters (16 bytes * 2)
|
||||
userID := "test@example.com"
|
||||
hash, err := HashUserID(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("HashUserID() error = %v", err)
|
||||
}
|
||||
|
||||
result := hash.String()
|
||||
if len(result) != 32 {
|
||||
t.Errorf("UserIDHash.String() length = %d, want 32", len(result))
|
||||
}
|
||||
|
||||
// Verify it's valid hex
|
||||
for i, c := range result {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashUserID_KnownValues(t *testing.T) {
|
||||
// Test with known BLAKE2b-128 values to ensure correct implementation
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
expected UserIDHash
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
userID: "",
|
||||
// BLAKE2b-128 of empty string
|
||||
expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70},
|
||||
},
|
||||
{
|
||||
name: "single character 'a'",
|
||||
userID: "a",
|
||||
// BLAKE2b-128 of "a"
|
||||
expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash, err := HashUserID(tt.userID)
|
||||
if err != nil {
|
||||
t.Errorf("HashUserID() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if hash != tt.expected {
|
||||
t.Errorf("HashUserID(%q) = %x, want %x",
|
||||
tt.userID, hash, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHashUserID(b *testing.B) {
|
||||
userID := "user@example.com"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = HashUserID(userID)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUserIDHash_String(b *testing.B) {
|
||||
hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe})
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = hash.String()
|
||||
}
|
||||
}
|
||||
@@ -2,12 +2,10 @@ package semaphoregroup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore.
|
||||
type SemaphoreGroup struct {
|
||||
waitGroup sync.WaitGroup
|
||||
semaphore chan struct{}
|
||||
}
|
||||
|
||||
@@ -18,31 +16,18 @@ func NewSemaphoreGroup(limit int) *SemaphoreGroup {
|
||||
}
|
||||
}
|
||||
|
||||
// Add increments the internal WaitGroup counter and acquires a semaphore slot.
|
||||
func (sg *SemaphoreGroup) Add(ctx context.Context) {
|
||||
sg.waitGroup.Add(1)
|
||||
|
||||
// Add acquire a slot
|
||||
func (sg *SemaphoreGroup) Add(ctx context.Context) error {
|
||||
// Acquire semaphore slot
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
return ctx.Err()
|
||||
case sg.semaphore <- struct{}{}:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Done decrements the internal WaitGroup counter and releases a semaphore slot.
|
||||
func (sg *SemaphoreGroup) Done(ctx context.Context) {
|
||||
sg.waitGroup.Done()
|
||||
|
||||
// Release semaphore slot
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-sg.semaphore:
|
||||
}
|
||||
}
|
||||
|
||||
// Wait waits until the internal WaitGroup counter is zero.
|
||||
func (sg *SemaphoreGroup) Wait() {
|
||||
sg.waitGroup.Wait()
|
||||
// Done releases a slot. Must be called after a successful Add.
|
||||
func (sg *SemaphoreGroup) Done() {
|
||||
<-sg.semaphore
|
||||
}
|
||||
|
||||
@@ -2,65 +2,89 @@ package semaphoregroup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSemaphoreGroup(t *testing.T) {
|
||||
semGroup := NewSemaphoreGroup(2)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
semGroup.Add(context.Background())
|
||||
go func(id int) {
|
||||
defer semGroup.Done(context.Background())
|
||||
|
||||
got := len(semGroup.semaphore)
|
||||
if got == 0 {
|
||||
t.Errorf("Expected semaphore length > 0 , got 0")
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
t.Logf("Goroutine %d is running\n", id)
|
||||
}(i)
|
||||
}
|
||||
|
||||
semGroup.Wait()
|
||||
|
||||
want := 0
|
||||
got := len(semGroup.semaphore)
|
||||
if got != want {
|
||||
t.Errorf("Expected semaphore length %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSemaphoreGroupContext(t *testing.T) {
|
||||
semGroup := NewSemaphoreGroup(1)
|
||||
semGroup.Add(context.Background())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
_ = semGroup.Add(context.Background())
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
t.Cleanup(cancel)
|
||||
rChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
semGroup.Add(ctx)
|
||||
rChan <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-rChan:
|
||||
case <-time.NewTimer(2 * time.Second).C:
|
||||
t.Error("Adding to semaphore group should not block when context is not done")
|
||||
}
|
||||
|
||||
semGroup.Done(context.Background())
|
||||
|
||||
ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
t.Cleanup(cancelDone)
|
||||
go func() {
|
||||
semGroup.Done(ctxDone)
|
||||
rChan <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-rChan:
|
||||
case <-time.NewTimer(2 * time.Second).C:
|
||||
t.Error("Releasing from semaphore group should not block when context is not done")
|
||||
if err := semGroup.Add(ctxTimeout); err == nil {
|
||||
t.Error("Adding to semaphore group should not block")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSemaphoreGroupFreeUp(t *testing.T) {
|
||||
semGroup := NewSemaphoreGroup(1)
|
||||
_ = semGroup.Add(context.Background())
|
||||
semGroup.Done()
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
t.Cleanup(cancel)
|
||||
if err := semGroup.Add(ctxTimeout); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSemaphoreGroupCanceledContext(t *testing.T) {
|
||||
semGroup := NewSemaphoreGroup(1)
|
||||
_ = semGroup.Add(context.Background())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
if err := semGroup.Add(ctx); err == nil {
|
||||
t.Error("Add should return error when context is already canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSemaphoreGroupCancelWhileWaiting(t *testing.T) {
|
||||
semGroup := NewSemaphoreGroup(1)
|
||||
_ = semGroup.Add(context.Background())
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
errChan <- semGroup.Add(ctx)
|
||||
}()
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
if err := <-errChan; err == nil {
|
||||
t.Error("Add should return error when context is canceled while waiting")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSemaphoreGroupHighConcurrency(t *testing.T) {
|
||||
const limit = 10
|
||||
const numGoroutines = 100
|
||||
|
||||
semGroup := NewSemaphoreGroup(limit)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := semGroup.Add(context.Background()); err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
semGroup.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all slots were released
|
||||
if got := len(semGroup.semaphore); got != 0 {
|
||||
t.Errorf("Expected semaphore to be empty, got %d slots occupied", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user