mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-23 18:56:38 +00:00
Compare commits
22 Commits
coderabbit
...
v0.61.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
08b782d6ba | ||
|
|
80a312cc9c | ||
|
|
9ba067391f | ||
|
|
7ac65bf1ad | ||
|
|
2e9c316852 | ||
|
|
96cdd56902 | ||
|
|
9ed1437442 | ||
|
|
a8604ef51c | ||
|
|
d88e046d00 | ||
|
|
1d2c7776fd | ||
|
|
4035f07248 | ||
|
|
ef2721f4e1 | ||
|
|
e11970e32e | ||
|
|
38f9d5ed58 | ||
|
|
b6a327e0c9 | ||
|
|
67f7b2404e | ||
|
|
73201c4f3e | ||
|
|
33d1761fe8 | ||
|
|
aa914a0f26 | ||
|
|
ab6a9e85de | ||
|
|
d3b123c76d | ||
|
|
fc4932a23f |
2
.github/workflows/golang-test-freebsd.yml
vendored
2
.github/workflows/golang-test-freebsd.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
|||||||
# check all component except management, since we do not support management server on freebsd
|
# check all component except management, since we do not support management server on freebsd
|
||||||
time go test -timeout 1m -failfast ./base62/...
|
time go test -timeout 1m -failfast ./base62/...
|
||||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
time go test -timeout 8m -failfast -v -p 1 ./client/...
|
||||||
time go test -timeout 1m -failfast ./dns/...
|
time go test -timeout 1m -failfast ./dns/...
|
||||||
time go test -timeout 1m -failfast ./encryption/...
|
time go test -timeout 1m -failfast ./encryption/...
|
||||||
time go test -timeout 1m -failfast ./formatter/...
|
time go test -timeout 1m -failfast ./formatter/...
|
||||||
|
|||||||
96
.github/workflows/release.yml
vendored
96
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.23"
|
SIGN_PIPE_VER: "v0.1.0"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -19,6 +19,100 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
release_freebsd_port:
|
||||||
|
name: "FreeBSD Port / Build & Test"
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Generate FreeBSD port diff
|
||||||
|
run: bash release_files/freebsd-port-diff.sh
|
||||||
|
|
||||||
|
- name: Generate FreeBSD port issue body
|
||||||
|
run: bash release_files/freebsd-port-issue-body.sh
|
||||||
|
|
||||||
|
- name: Check if diff was generated
|
||||||
|
id: check_diff
|
||||||
|
run: |
|
||||||
|
if ls netbird-*.diff 1> /dev/null 2>&1; then
|
||||||
|
echo "diff_exists=true" >> $GITHUB_OUTPUT
|
||||||
|
else
|
||||||
|
echo "diff_exists=false" >> $GITHUB_OUTPUT
|
||||||
|
echo "No diff file generated (port may already be up to date)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Extract version
|
||||||
|
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||||
|
id: version
|
||||||
|
run: |
|
||||||
|
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
|
||||||
|
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||||
|
echo "Generated files for version: $VERSION"
|
||||||
|
cat netbird-*.diff
|
||||||
|
|
||||||
|
- name: Test FreeBSD port
|
||||||
|
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||||
|
uses: vmactions/freebsd-vm@v1
|
||||||
|
with:
|
||||||
|
usesh: true
|
||||||
|
copyback: false
|
||||||
|
release: "15.0"
|
||||||
|
prepare: |
|
||||||
|
# Install required packages
|
||||||
|
pkg install -y git curl portlint go
|
||||||
|
|
||||||
|
# Install Go for building
|
||||||
|
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||||
|
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||||
|
curl -LO "$GO_URL"
|
||||||
|
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||||
|
|
||||||
|
# Clone ports tree (shallow, only what we need)
|
||||||
|
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
|
||||||
|
cd /usr/ports
|
||||||
|
|
||||||
|
run: |
|
||||||
|
set -e -x
|
||||||
|
export PATH=$PATH:/usr/local/go/bin
|
||||||
|
|
||||||
|
# Find the diff file
|
||||||
|
echo "Finding diff file..."
|
||||||
|
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
|
||||||
|
echo "Found: $DIFF_FILE"
|
||||||
|
|
||||||
|
if [[ -z "$DIFF_FILE" ]]; then
|
||||||
|
echo "ERROR: Could not find diff file"
|
||||||
|
find ~ -name "*.diff" -type f 2>/dev/null || true
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
|
||||||
|
cd /usr/ports
|
||||||
|
patch -p1 -V none < "$DIFF_FILE"
|
||||||
|
|
||||||
|
# Show patched Makefile
|
||||||
|
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||||
|
|
||||||
|
cd /usr/ports/security/netbird
|
||||||
|
export BATCH=yes
|
||||||
|
make package
|
||||||
|
pkg add ./work/pkg/netbird-*.pkg
|
||||||
|
|
||||||
|
netbird version | grep "$version"
|
||||||
|
|
||||||
|
echo "FreeBSD port test completed successfully!"
|
||||||
|
|
||||||
|
- name: Upload FreeBSD port files
|
||||||
|
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: freebsd-port-files
|
||||||
|
path: |
|
||||||
|
./netbird-*-issue.txt
|
||||||
|
./netbird-*.diff
|
||||||
|
retention-days: 30
|
||||||
|
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-latest-m
|
runs-on: ubuntu-latest-m
|
||||||
env:
|
env:
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
|||||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||||
|
|
||||||
<p float="left" align="middle">
|
<p float="left" align="middle">
|
||||||
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
|
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||||
|
|||||||
@@ -386,6 +386,97 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||||
|
t.Skipf("iptables-save not available on this system: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First ensure iptables-nft tables exist by running iptables-save
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err, "failed to create manager")
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := manager.Close(nil)
|
||||||
|
require.NoError(t, err, "failed to reset manager state")
|
||||||
|
|
||||||
|
// Verify iptables output after reset
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
})
|
||||||
|
|
||||||
|
const octet2Count = 25
|
||||||
|
const octet3Count = 255
|
||||||
|
prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1))
|
||||||
|
for i := 1; i < octet2Count; i++ {
|
||||||
|
for j := 1; j < octet3Count; j++ {
|
||||||
|
addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0})
|
||||||
|
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
prefixes,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add route filtering rule")
|
||||||
|
|
||||||
|
stdout, stderr = runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||||
|
t.Skipf("iptables-save not available on this system: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First ensure iptables-nft tables exist by running iptables-save
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err, "failed to create manager")
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := manager.Close(nil)
|
||||||
|
require.NoError(t, err, "failed to reset manager state")
|
||||||
|
|
||||||
|
// Verify iptables output after reset
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{},
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add route filtering rule")
|
||||||
|
|
||||||
|
stdout, stderr = runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
}
|
||||||
|
|
||||||
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
require.Equal(t, len(got), len(want), "expression count mismatch")
|
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||||
|
|||||||
@@ -48,9 +48,11 @@ const (
|
|||||||
|
|
||||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||||
ipTCPHeaderMinSize = 40
|
ipTCPHeaderMinSize = 40
|
||||||
)
|
|
||||||
|
|
||||||
const refreshRulesMapError = "refresh rules map: %w"
|
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||||
|
maxPrefixesSet = 1500
|
||||||
|
refreshRulesMapError = "refresh rules map: %w"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||||
@@ -513,16 +515,35 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
elements := convertPrefixesToSet(prefixes)
|
elements := convertPrefixesToSet(prefixes)
|
||||||
if err := r.conn.AddSet(nfset, elements); err != nil {
|
nElements := len(elements)
|
||||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
maxElements := maxPrefixesSet * 2
|
||||||
|
initialElements := elements[:min(maxElements, nElements)]
|
||||||
|
|
||||||
|
if err := r.conn.AddSet(nfset, initialElements); err != nil {
|
||||||
|
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
|
||||||
|
}
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return nil, fmt.Errorf("flush error: %w", err)
|
return nil, fmt.Errorf("flush error: %w", err)
|
||||||
}
|
}
|
||||||
|
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
|
||||||
|
|
||||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
var subEnd int
|
||||||
|
for subStart := maxElements; subStart < nElements; subStart += maxElements {
|
||||||
|
subEnd = min(subStart+maxElements, nElements)
|
||||||
|
subElement := elements[subStart:subEnd]
|
||||||
|
nSubPrefixes := len(subElement) / 2
|
||||||
|
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
|
||||||
|
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
|
||||||
|
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush error: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
|
||||||
return nfset, nil
|
return nfset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -45,10 +46,31 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0).
|
||||||
|
// This typically means the Swift code couldn't find the utun control socket.
|
||||||
|
var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)")
|
||||||
|
|
||||||
func (t *TunDevice) Create() (WGConfigurer, error) {
|
func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||||
log.Infof("create tun interface")
|
log.Infof("create tun interface")
|
||||||
|
|
||||||
dupTunFd, err := unix.Dup(t.tunFd)
|
var tunDevice tun.Device
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Validate the tunnel file descriptor.
|
||||||
|
// On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider.
|
||||||
|
// A value of 0 means the Swift code couldn't find the utun control socket
|
||||||
|
// (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in
|
||||||
|
// tvOS SDK headers). This is a hard error - there's no viable fallback
|
||||||
|
// since tun.CreateTUN() cannot work within the iOS/tvOS sandbox.
|
||||||
|
if t.tunFd == 0 {
|
||||||
|
log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " +
|
||||||
|
"On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.")
|
||||||
|
return nil, ErrInvalidTunnelFD
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider
|
||||||
|
var dupTunFd int
|
||||||
|
dupTunFd, err = unix.Dup(t.tunFd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to dup tun fd: %v", err)
|
log.Errorf("Unable to dup tun fd: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -60,7 +82,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
_ = unix.Close(dupTunFd)
|
_ = unix.Close(dupTunFd)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
|
tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to create new tun device from fd: %v", err)
|
log.Errorf("Unable to create new tun device from fd: %v", err)
|
||||||
_ = unix.Close(dupTunFd)
|
_ = unix.Close(dupTunFd)
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ type DefaultServer struct {
|
|||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig HostDNSConfig
|
currentConfig HostDNSConfig
|
||||||
|
currentConfigHash uint64
|
||||||
handlerChain *HandlerChain
|
handlerChain *HandlerChain
|
||||||
extraDomains map[domain.Domain]int
|
extraDomains map[domain.Domain]int
|
||||||
|
|
||||||
@@ -207,6 +208,7 @@ func newDefaultServer(
|
|||||||
hostsDNSHolder: newHostsDNSHolder(),
|
hostsDNSHolder: newHostsDNSHolder(),
|
||||||
hostManager: &noopHostConfigurator{},
|
hostManager: &noopHostConfigurator{},
|
||||||
mgmtCacheResolver: mgmtCacheResolver,
|
mgmtCacheResolver: mgmtCacheResolver,
|
||||||
|
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
||||||
}
|
}
|
||||||
|
|
||||||
// register with root zone, handler chain takes care of the routing
|
// register with root zone, handler chain takes care of the routing
|
||||||
@@ -586,8 +588,29 @@ func (s *DefaultServer) applyHostConfig() {
|
|||||||
|
|
||||||
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
|
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
|
||||||
|
|
||||||
|
hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||||
|
ZeroNil: true,
|
||||||
|
IgnoreZeroValue: true,
|
||||||
|
SlicesAsSets: true,
|
||||||
|
UseStringer: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err)
|
||||||
|
// Fall through to apply config anyway (fail-safe approach)
|
||||||
|
} else if s.currentConfigHash == hash {
|
||||||
|
log.Debugf("not applying host config as there are no changes")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("applying host config as there are changes")
|
||||||
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||||
log.Errorf("failed to apply DNS host manager update: %v", err)
|
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only update hash if it was computed successfully and config was applied
|
||||||
|
if err == nil {
|
||||||
|
s.currentConfigHash = hash
|
||||||
}
|
}
|
||||||
|
|
||||||
s.registerFallback(config)
|
s.registerFallback(config)
|
||||||
|
|||||||
@@ -1602,7 +1602,10 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
"other.example.com.",
|
"other.example.com.",
|
||||||
"duplicate.example.com.",
|
"duplicate.example.com.",
|
||||||
},
|
},
|
||||||
applyHostConfigCall: 4,
|
// Expect 3 calls instead of 4 because when deregistering duplicate.example.com,
|
||||||
|
// the domain remains in the config (ref count goes from 2 to 1), so the host
|
||||||
|
// config hash doesn't change and applyDNSConfig is not called.
|
||||||
|
applyHostConfigCall: 3,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Config update with new domains after registration",
|
name: "Config update with new domains after registration",
|
||||||
@@ -1657,7 +1660,10 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
expectedMatchOnly: []string{
|
expectedMatchOnly: []string{
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
},
|
},
|
||||||
applyHostConfigCall: 3,
|
// Expect 2 calls instead of 3 because when deregistering protected.example.com,
|
||||||
|
// it's removed from extraDomains but still remains in the config (from customZones),
|
||||||
|
// so the host config hash doesn't change and applyDNSConfig is not called.
|
||||||
|
applyHostConfigCall: 2,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Register domain that is part of nameserver group",
|
name: "Register domain that is part of nameserver group",
|
||||||
|
|||||||
@@ -1121,6 +1121,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
|
|
||||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||||
|
|
||||||
|
// Filter out own peer from the remote peers list
|
||||||
|
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||||
|
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||||
|
for _, p := range networkMap.GetRemotePeers() {
|
||||||
|
if p.GetWgPubKey() != localPubKey {
|
||||||
|
remotePeers = append(remotePeers, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// cleanup request, most likely our peer has been deleted
|
// cleanup request, most likely our peer has been deleted
|
||||||
if networkMap.GetRemotePeersIsEmpty() {
|
if networkMap.GetRemotePeersIsEmpty() {
|
||||||
err := e.removeAllPeers()
|
err := e.removeAllPeers()
|
||||||
@@ -1129,32 +1138,34 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err := e.removePeers(networkMap.GetRemotePeers())
|
err := e.removePeers(remotePeers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.modifyPeers(networkMap.GetRemotePeers())
|
err = e.modifyPeers(remotePeers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.addNewPeers(networkMap.GetRemotePeers())
|
err = e.addNewPeers(remotePeers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
|
|
||||||
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
|
e.updatePeerSSHHostKeys(remotePeers)
|
||||||
|
|
||||||
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||||
log.Warnf("failed to update SSH client config: %v", err)
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||||
}
|
}
|
||||||
|
|
||||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
|
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||||
|
|
||||||
e.networkSerial = serial
|
e.networkSerial = serial
|
||||||
|
|||||||
@@ -11,15 +11,18 @@ import (
|
|||||||
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sshServer interface {
|
type sshServer interface {
|
||||||
Start(ctx context.Context, addr netip.AddrPort) error
|
Start(ctx context.Context, addr netip.AddrPort) error
|
||||||
Stop() error
|
Stop() error
|
||||||
GetStatus() (bool, []sshserver.SessionInfo)
|
GetStatus() (bool, []sshserver.SessionInfo)
|
||||||
|
UpdateSSHAuth(config *sshauth.Config)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) setupSSHPortRedirection() error {
|
func (e *Engine) setupSSHPortRedirection() error {
|
||||||
@@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
|
|||||||
|
|
||||||
return sshServer.GetStatus()
|
return sshServer.GetStatus()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
|
||||||
|
func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
|
||||||
|
if sshAuth == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.sshServer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
protoUsers := sshAuth.GetAuthorizedUsers()
|
||||||
|
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||||
|
for i, hash := range protoUsers {
|
||||||
|
if len(hash) != 16 {
|
||||||
|
log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
machineUsers := make(map[string][]uint32)
|
||||||
|
for osUser, indexes := range sshAuth.GetMachineUsers() {
|
||||||
|
machineUsers[osUser] = indexes.GetIndexes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update SSH server with new authorization configuration
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshAuth.GetUserIDClaim(),
|
||||||
|
AuthorizedUsers: authorizedUsers,
|
||||||
|
MachineUsers: machineUsers,
|
||||||
|
}
|
||||||
|
|
||||||
|
e.sshServer.UpdateSSHAuth(authConfig)
|
||||||
|
}
|
||||||
|
|||||||
@@ -148,13 +148,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
|||||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||||
// be used.
|
// be used.
|
||||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||||
conn.semaphore.Add(engineCtx)
|
if err := conn.semaphore.Add(engineCtx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
if conn.opened {
|
if conn.opened {
|
||||||
conn.semaphore.Done(engineCtx)
|
conn.semaphore.Done()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
conn.semaphore.Done()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
conn.workerICE = workerICE
|
conn.workerICE = workerICE
|
||||||
@@ -200,7 +203,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
defer conn.wg.Done()
|
defer conn.wg.Done()
|
||||||
|
|
||||||
conn.waitInitialRandomSleepTime(conn.ctx)
|
conn.waitInitialRandomSleepTime(conn.ctx)
|
||||||
conn.semaphore.Done(conn.ctx)
|
conn.semaphore.Done()
|
||||||
|
|
||||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package profilemanager
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -820,3 +821,85 @@ func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
|||||||
func WriteOutConfig(path string, config *Config) error {
|
func WriteOutConfig(path string, config *Config) error {
|
||||||
return util.WriteJson(context.Background(), path, config)
|
return util.WriteJson(context.Background(), path, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DirectWriteOutConfig writes config directly without atomic temp file operations.
|
||||||
|
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||||
|
func DirectWriteOutConfig(path string, config *Config) error {
|
||||||
|
return util.DirectWriteJson(context.Background(), path, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
||||||
|
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||||
|
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||||
|
if !fileExists(input.ConfigPath) {
|
||||||
|
log.Infof("generating new config %s", input.ConfigPath)
|
||||||
|
cfg, err := createNewConfig(input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg)
|
||||||
|
return cfg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||||
|
input.PreSharedKey = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enforce permissions on existing config files (same as UpdateOrCreateConfig)
|
||||||
|
if err := util.EnforcePermission(input.ConfigPath); err != nil {
|
||||||
|
log.Errorf("failed to enforce permission on config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return directUpdate(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
func directUpdate(input ConfigInput) (*Config, error) {
|
||||||
|
config := &Config{}
|
||||||
|
|
||||||
|
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := config.apply(input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated {
|
||||||
|
if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigToJSON serializes a Config struct to a JSON string.
|
||||||
|
// This is useful for exporting config to alternative storage mechanisms
|
||||||
|
// (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||||
|
func ConfigToJSON(config *Config) (string, error) {
|
||||||
|
bs, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(bs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigFromJSON deserializes a JSON string to a Config struct.
|
||||||
|
// This is useful for restoring config from alternative storage mechanisms.
|
||||||
|
// After unmarshaling, defaults are applied to ensure the config is fully initialized.
|
||||||
|
func ConfigFromJSON(jsonStr string) (*Config, error) {
|
||||||
|
config := &Config{}
|
||||||
|
err := json.Unmarshal([]byte(jsonStr), config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults to ensure required fields are initialized.
|
||||||
|
// This mirrors what readConfig does after loading from file.
|
||||||
|
if _, err := config.apply(ConfigInput{}); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to apply defaults to config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ const (
|
|||||||
|
|
||||||
defaultTempDir = "/var/lib/netbird/tmp-install"
|
defaultTempDir = "/var/lib/netbird/tmp-install"
|
||||||
|
|
||||||
pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
|
pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ const (
|
|||||||
|
|
||||||
msiLogFile = "msi.log"
|
msiLogFile = "msi.log"
|
||||||
|
|
||||||
msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
|
msiDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
|
||||||
exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
|
exeDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -75,6 +75,8 @@ type Client struct {
|
|||||||
dnsManager dns.IosDnsManager
|
dnsManager dns.IosDnsManager
|
||||||
loginComplete bool
|
loginComplete bool
|
||||||
connectClient *internal.ConnectClient
|
connectClient *internal.ConnectClient
|
||||||
|
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
|
||||||
|
preloadedConfig *profilemanager.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -92,17 +94,44 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetConfigFromJSON loads config from a JSON string into memory.
|
||||||
|
// This is used on tvOS where file writes to App Group containers are blocked.
|
||||||
|
// When set, IsLoginRequired() and Run() will use this preloaded config instead of reading from file.
|
||||||
|
func (c *Client) SetConfigFromJSON(jsonStr string) error {
|
||||||
|
cfg, err := profilemanager.ConfigFromJSON(jsonStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("SetConfigFromJSON: failed to parse config JSON: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.preloadedConfig = cfg
|
||||||
|
log.Infof("SetConfigFromJSON: config loaded successfully from JSON")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||||
exportEnvList(envList)
|
exportEnvList(envList)
|
||||||
log.Infof("Starting NetBird client")
|
log.Infof("Starting NetBird client")
|
||||||
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
|
||||||
ConfigPath: c.cfgFile,
|
var cfg *profilemanager.Config
|
||||||
StateFilePath: c.stateFile,
|
var err error
|
||||||
})
|
|
||||||
if err != nil {
|
// Use preloaded config if available (tvOS where file writes are blocked)
|
||||||
return err
|
if c.preloadedConfig != nil {
|
||||||
|
log.Infof("Run: using preloaded config from memory")
|
||||||
|
cfg = c.preloadedConfig
|
||||||
|
} else {
|
||||||
|
log.Infof("Run: loading config from file")
|
||||||
|
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
|
||||||
|
// which are blocked by the tvOS sandbox in App Group containers
|
||||||
|
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
|
ConfigPath: c.cfgFile,
|
||||||
|
StateFilePath: c.stateFile,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||||
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||||
@@ -120,7 +149,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
|||||||
c.ctxCancelLock.Unlock()
|
c.ctxCancelLock.Unlock()
|
||||||
|
|
||||||
auth := NewAuthWithConfig(ctx, cfg)
|
auth := NewAuthWithConfig(ctx, cfg)
|
||||||
err = auth.Login()
|
err = auth.LoginSync()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -208,14 +237,45 @@ func (c *Client) IsLoginRequired() bool {
|
|||||||
defer c.ctxCancelLock.Unlock()
|
defer c.ctxCancelLock.Unlock()
|
||||||
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||||
|
|
||||||
cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
var cfg *profilemanager.Config
|
||||||
ConfigPath: c.cfgFile,
|
var err error
|
||||||
})
|
|
||||||
|
|
||||||
needsLogin, _ := internal.IsLoginRequired(ctx, cfg)
|
// Use preloaded config if available (tvOS where file writes are blocked)
|
||||||
|
if c.preloadedConfig != nil {
|
||||||
|
log.Infof("IsLoginRequired: using preloaded config from memory")
|
||||||
|
cfg = c.preloadedConfig
|
||||||
|
} else {
|
||||||
|
log.Infof("IsLoginRequired: loading config from file")
|
||||||
|
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
|
||||||
|
// which are blocked by the tvOS sandbox in App Group containers
|
||||||
|
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
|
ConfigPath: c.cfgFile,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("IsLoginRequired: failed to load config: %v", err)
|
||||||
|
// If we can't load config, assume login is required
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg == nil {
|
||||||
|
log.Errorf("IsLoginRequired: config is nil")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("IsLoginRequired: check failed: %v", err)
|
||||||
|
// If the check fails, assume login is required to be safe
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
log.Infof("IsLoginRequired: needsLogin=%v", needsLogin)
|
||||||
return needsLogin
|
return needsLogin
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loginForMobileAuthTimeout is the timeout for requesting auth info from the server
|
||||||
|
const loginForMobileAuthTimeout = 30 * time.Second
|
||||||
|
|
||||||
func (c *Client) LoginForMobile() string {
|
func (c *Client) LoginForMobile() string {
|
||||||
var ctx context.Context
|
var ctx context.Context
|
||||||
//nolint
|
//nolint
|
||||||
@@ -228,16 +288,26 @@ func (c *Client) LoginForMobile() string {
|
|||||||
defer c.ctxCancelLock.Unlock()
|
defer c.ctxCancelLock.Unlock()
|
||||||
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||||
|
|
||||||
cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
|
||||||
|
// which are blocked by the tvOS sandbox in App Group containers
|
||||||
|
cfg, err := profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("LoginForMobile: failed to load config: %v", err)
|
||||||
|
return fmt.Sprintf("failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
||||||
|
authInfoCtx, authInfoCancel := context.WithTimeout(ctx, loginForMobileAuthTimeout)
|
||||||
|
defer authInfoCancel()
|
||||||
|
|
||||||
|
flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
@@ -249,10 +319,14 @@ func (c *Client) LoginForMobile() string {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwtToken := tokenInfo.GetTokenToUse()
|
jwtToken := tokenInfo.GetTokenToUse()
|
||||||
_ = internal.Login(ctx, cfg, "", jwtToken)
|
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
|
||||||
|
log.Errorf("LoginForMobile: Login failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
c.loginComplete = true
|
c.loginComplete = true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
"github.com/netbirdio/netbird/client/cmd"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
)
|
)
|
||||||
@@ -33,7 +34,8 @@ type ErrListener interface {
|
|||||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||||
// the backend want to show an url for the user
|
// the backend want to show an url for the user
|
||||||
type URLOpener interface {
|
type URLOpener interface {
|
||||||
Open(string)
|
Open(url string, userCode string)
|
||||||
|
OnLoginSuccess()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth can register or login new client
|
// Auth can register or login new client
|
||||||
@@ -72,13 +74,32 @@ func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth
|
|||||||
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||||
func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||||
|
if listener == nil {
|
||||||
|
log.Errorf("SaveConfigIfSSOSupported: listener is nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
sso, err := a.saveConfigIfSSOSupported()
|
||||||
|
if err != nil {
|
||||||
|
listener.OnError(err)
|
||||||
|
} else {
|
||||||
|
listener.OnSuccess(sso)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
supportsSSO := true
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
s, ok := gstatus.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||||
supportsSSO = false
|
supportsSSO = false
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
@@ -97,12 +118,29 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
|
|||||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||||
|
// which are blocked by the tvOS sandbox in App Group containers
|
||||||
|
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
|
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
|
||||||
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) {
|
||||||
|
if resultListener == nil {
|
||||||
|
log.Errorf("LoginWithSetupKeyAndSaveConfig: resultListener is nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName)
|
||||||
|
if err != nil {
|
||||||
|
resultListener.OnError(err)
|
||||||
|
} else {
|
||||||
|
resultListener.OnSuccess()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
//nolint
|
//nolint
|
||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
|
||||||
@@ -118,10 +156,14 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
|||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||||
|
// which are blocked by the tvOS sandbox in App Group containers
|
||||||
|
return profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) Login() error {
|
// LoginSync performs a synchronous login check without UI interaction
|
||||||
|
// Used for background VPN connection where user should already be authenticated
|
||||||
|
func (a *Auth) LoginSync() error {
|
||||||
var needsLogin bool
|
var needsLogin bool
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
@@ -135,23 +177,142 @@ func (a *Auth) Login() error {
|
|||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
return fmt.Errorf("Not authenticated")
|
return fmt.Errorf("not authenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err = a.withBackOff(a.ctx, func() error {
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
return nil
|
// PermissionDenied means registration is required or peer is blocked
|
||||||
|
return backoff.Permanent(err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("login failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login performs interactive login with device authentication support
|
||||||
|
// Deprecated: Use LoginWithDeviceName instead to ensure proper device naming on tvOS
|
||||||
|
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool) {
|
||||||
|
// Use empty device name - system will use hostname as fallback
|
||||||
|
a.LoginWithDeviceName(resultListener, urlOpener, forceDeviceAuth, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithDeviceName performs interactive login with device authentication support
|
||||||
|
// The deviceName parameter allows specifying a custom device name (required for tvOS)
|
||||||
|
func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool, deviceName string) {
|
||||||
|
if resultListener == nil {
|
||||||
|
log.Errorf("LoginWithDeviceName: resultListener is nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if urlOpener == nil {
|
||||||
|
log.Errorf("LoginWithDeviceName: urlOpener is nil")
|
||||||
|
resultListener.OnError(fmt.Errorf("urlOpener is nil"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
err := a.login(urlOpener, forceDeviceAuth, deviceName)
|
||||||
|
if err != nil {
|
||||||
|
resultListener.OnError(err)
|
||||||
|
} else {
|
||||||
|
resultListener.OnSuccess()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
||||||
|
var needsLogin bool
|
||||||
|
|
||||||
|
// Create context with device name if provided
|
||||||
|
ctx := a.ctx
|
||||||
|
if deviceName != "" {
|
||||||
|
//nolint:staticcheck
|
||||||
|
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if we need to generate JWT token
|
||||||
|
err := a.withBackOff(ctx, func() (err error) {
|
||||||
|
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
|
||||||
|
return
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
jwtToken := ""
|
||||||
|
if needsLogin {
|
||||||
|
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
|
}
|
||||||
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.withBackOff(ctx, func() error {
|
||||||
|
err := internal.Login(ctx, a.config, "", jwtToken)
|
||||||
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
|
// PermissionDenied means registration is required or peer is blocked
|
||||||
|
return backoff.Permanent(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("login failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the config before notifying success to ensure persistence completes
|
||||||
|
// before the callback potentially triggers teardown on the Swift side.
|
||||||
|
// Note: This differs from Android which doesn't save config after login.
|
||||||
|
// On iOS/tvOS, we save here because:
|
||||||
|
// 1. The config may have been modified during login (e.g., new tokens)
|
||||||
|
// 2. On tvOS, the Network Extension context may be the only place with
|
||||||
|
// write permissions to the App Group container
|
||||||
|
if a.cfgPath != "" {
|
||||||
|
if err := profilemanager.DirectWriteOutConfig(a.cfgPath, a.config); err != nil {
|
||||||
|
log.Warnf("failed to save config after login: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify caller of successful login synchronously before returning
|
||||||
|
urlOpener.OnLoginSuccess()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const authInfoRequestTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||||
|
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
||||||
|
authInfoCtx, authInfoCancel := context.WithTimeout(a.ctx, authInfoRequestTimeout)
|
||||||
|
defer authInfoCancel()
|
||||||
|
|
||||||
|
flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
|
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||||
|
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||||
|
defer cancel()
|
||||||
|
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||||
return backoff.RetryNotify(
|
return backoff.RetryNotify(
|
||||||
bf,
|
bf,
|
||||||
@@ -160,3 +321,24 @@ func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
|||||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfigJSON returns the current config as a JSON string.
|
||||||
|
// This can be used by the caller to persist the config via alternative storage
|
||||||
|
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||||
|
func (a *Auth) GetConfigJSON() (string, error) {
|
||||||
|
if a.config == nil {
|
||||||
|
return "", fmt.Errorf("no config available")
|
||||||
|
}
|
||||||
|
return profilemanager.ConfigToJSON(a.config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConfigFromJSON loads config from a JSON string.
|
||||||
|
// This can be used to restore config from alternative storage mechanisms.
|
||||||
|
func (a *Auth) SetConfigFromJSON(jsonStr string) error {
|
||||||
|
cfg, err := profilemanager.ConfigFromJSON(jsonStr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
a.config = cfg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -112,6 +112,8 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
|||||||
|
|
||||||
// Commit write out the changes into config file
|
// Commit write out the changes into config file
|
||||||
func (p *Preferences) Commit() error {
|
func (p *Preferences) Commit() error {
|
||||||
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
|
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
|
||||||
|
// which are blocked by the tvOS sandbox in App Group containers
|
||||||
|
_, err := profilemanager.DirectUpdateOrCreateConfig(p.configInput)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
184
client/ssh/auth/auth.go
Normal file
184
client/ssh/auth/auth.go
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultUserIDClaim is the default JWT claim used to extract user IDs
|
||||||
|
DefaultUserIDClaim = "sub"
|
||||||
|
// Wildcard is a special user ID that matches all users
|
||||||
|
Wildcard = "*"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||||
|
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||||
|
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||||
|
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Authorizer handles SSH fine-grained access control authorization
|
||||||
|
type Authorizer struct {
|
||||||
|
// UserIDClaim is the JWT claim to extract the user ID from
|
||||||
|
userIDClaim string
|
||||||
|
|
||||||
|
// authorizedUsers is a list of hashed user IDs authorized to access this peer
|
||||||
|
authorizedUsers []sshuserhash.UserIDHash
|
||||||
|
|
||||||
|
// machineUsers maps OS login usernames to lists of authorized user indexes
|
||||||
|
machineUsers map[string][]uint32
|
||||||
|
|
||||||
|
// mu protects the list of users
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config contains configuration for the SSH authorizer
|
||||||
|
type Config struct {
|
||||||
|
// UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
|
||||||
|
UserIDClaim string
|
||||||
|
|
||||||
|
// AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
|
||||||
|
AuthorizedUsers []sshuserhash.UserIDHash
|
||||||
|
|
||||||
|
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
|
||||||
|
// If a user wants to login as a specific OS user, their index must be in the corresponding list
|
||||||
|
MachineUsers map[string][]uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthorizer creates a new SSH authorizer with empty configuration
|
||||||
|
func NewAuthorizer() *Authorizer {
|
||||||
|
a := &Authorizer{
|
||||||
|
userIDClaim: DefaultUserIDClaim,
|
||||||
|
machineUsers: make(map[string][]uint32),
|
||||||
|
}
|
||||||
|
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update updates the authorizer configuration with new values
|
||||||
|
func (a *Authorizer) Update(config *Config) {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
// Clear authorization
|
||||||
|
a.userIDClaim = DefaultUserIDClaim
|
||||||
|
a.authorizedUsers = []sshuserhash.UserIDHash{}
|
||||||
|
a.machineUsers = make(map[string][]uint32)
|
||||||
|
log.Info("SSH authorization cleared")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userIDClaim := config.UserIDClaim
|
||||||
|
if userIDClaim == "" {
|
||||||
|
userIDClaim = DefaultUserIDClaim
|
||||||
|
}
|
||||||
|
a.userIDClaim = userIDClaim
|
||||||
|
|
||||||
|
// Store authorized users list
|
||||||
|
a.authorizedUsers = config.AuthorizedUsers
|
||||||
|
|
||||||
|
// Store machine users mapping
|
||||||
|
machineUsers := make(map[string][]uint32)
|
||||||
|
for osUser, indexes := range config.MachineUsers {
|
||||||
|
if len(indexes) > 0 {
|
||||||
|
machineUsers[osUser] = indexes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
a.machineUsers = machineUsers
|
||||||
|
|
||||||
|
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
|
||||||
|
len(config.AuthorizedUsers), len(machineUsers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorize validates if a user is authorized to login as the specified OS user
|
||||||
|
// Returns nil if authorized, or an error describing why authorization failed
|
||||||
|
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
|
||||||
|
if jwtUserID == "" {
|
||||||
|
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
|
||||||
|
return ErrEmptyUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash the JWT user ID for comparison
|
||||||
|
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
|
||||||
|
return fmt.Errorf("failed to hash user ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
a.mu.RLock()
|
||||||
|
defer a.mu.RUnlock()
|
||||||
|
|
||||||
|
// Find the index of this user in the authorized list
|
||||||
|
userIndex, found := a.findUserIndex(hashedUserID)
|
||||||
|
if !found {
|
||||||
|
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
|
||||||
|
return ErrUserNotAuthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
|
||||||
|
// Checks wildcard mapping first, then specific OS user mappings
|
||||||
|
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
|
||||||
|
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
|
||||||
|
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
|
||||||
|
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
|
||||||
|
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for specific OS username mapping
|
||||||
|
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
|
||||||
|
if !hasMachineUserMapping {
|
||||||
|
// No mapping for this OS user - deny by default (fail closed)
|
||||||
|
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
|
||||||
|
return ErrNoMachineUserMapping
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if user's index is in the allowed indexes for this specific OS user
|
||||||
|
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
|
||||||
|
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
|
||||||
|
return ErrUserNotMappedToOSUser
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserIDClaim returns the JWT claim name used to extract user IDs
|
||||||
|
func (a *Authorizer) GetUserIDClaim() string {
|
||||||
|
a.mu.RLock()
|
||||||
|
defer a.mu.RUnlock()
|
||||||
|
return a.userIDClaim
|
||||||
|
}
|
||||||
|
|
||||||
|
// findUserIndex finds the index of a hashed user ID in the authorized users list
|
||||||
|
// Returns the index and true if found, 0 and false if not found
|
||||||
|
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
|
||||||
|
for i, id := range a.authorizedUsers {
|
||||||
|
if id == hashedUserID {
|
||||||
|
return i, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIndexInList checks if an index exists in a list of indexes
|
||||||
|
func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
|
||||||
|
for _, idx := range indexes {
|
||||||
|
if idx == index {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
612
client/ssh/auth/auth_test.go
Normal file
612
client/ssh/auth/auth_test.go
Normal file
@@ -0,0 +1,612 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/sshauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up authorized users list with one user
|
||||||
|
authorizedUserHash, err := sshauth.HashUserID("authorized-user")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Try to authorize a different user
|
||||||
|
err = authorizer.Authorize("unauthorized-user", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||||
|
MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// All attempts should fail when no machine user mappings exist (fail closed)
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user3Hash, err := sshauth.HashUserID("user3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0, 1}, // user1 and user2 can access root
|
||||||
|
"postgres": {1, 2}, // user2 and user3 can access postgres
|
||||||
|
"admin": {0}, // only user1 can access admin
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 (index 0) should access root and admin
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user1", "admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// user2 (index 1) should access root and postgres
|
||||||
|
err = authorizer.Authorize("user2", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "postgres")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// user3 (index 2) should access postgres
|
||||||
|
err = authorizer.Authorize("user3", "postgres")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up authorized users list
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user3Hash, err := sshauth.HashUserID("user3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0, 1}, // user1 and user2 can access root
|
||||||
|
"postgres": {1, 2}, // user2 and user3 can access postgres
|
||||||
|
"admin": {0}, // only user1 can access admin
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 (index 0) should NOT access postgres
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
|
// user2 (index 1) should NOT access admin
|
||||||
|
err = authorizer.Authorize("user2", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
|
// user3 (index 2) should NOT access root
|
||||||
|
err = authorizer.Authorize("user3", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
|
// user3 (index 2) should NOT access admin
|
||||||
|
err = authorizer.Authorize("user3", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up authorized users list
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0}, // only root is mapped
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 should NOT access an unmapped OS user (fail closed)
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up authorized users list
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Empty user ID should fail
|
||||||
|
err = authorizer.Authorize("", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrEmptyUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up multiple authorized users
|
||||||
|
userHashes := make([]sshauth.UserIDHash, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
userHashes[i] = hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create machine user mapping for all users
|
||||||
|
rootIndexes := make([]uint32, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
rootIndexes[i] = uint32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: userHashes,
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": rootIndexes,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// All users should be authorized for root
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
|
||||||
|
assert.NoError(t, err, "user%d should be authorized", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// User not in list should fail
|
||||||
|
err := authorizer.Authorize("unknown-user", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up initial configuration
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{"root": {0}},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 should be authorized
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Clear configuration
|
||||||
|
authorizer.Update(nil)
|
||||||
|
|
||||||
|
// user1 should no longer be authorized
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Machine users with empty index lists should be filtered out
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0},
|
||||||
|
"postgres": {}, // empty list - should be filtered out
|
||||||
|
"admin": nil, // nil list - should be filtered out
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// root should work
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// postgres should fail (no mapping)
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
// admin should fail (no mapping)
|
||||||
|
err = authorizer.Authorize("user1", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up with custom user ID claim
|
||||||
|
user1Hash, err := sshauth.HashUserID("user@example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: "email",
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Verify the custom claim is set
|
||||||
|
assert.Equal(t, "email", authorizer.GetUserIDClaim())
|
||||||
|
|
||||||
|
// Authorize with email as user ID
|
||||||
|
err = authorizer.Authorize("user@example.com", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Verify default claim
|
||||||
|
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
|
||||||
|
assert.Equal(t, "sub", authorizer.GetUserIDClaim())
|
||||||
|
|
||||||
|
// Set up with empty user ID claim (should use default)
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: "", // empty - should use default
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Should fall back to default
|
||||||
|
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Create a large authorized users list
|
||||||
|
const numUsers = 1000
|
||||||
|
userHashes := make([]sshauth.UserIDHash, numUsers)
|
||||||
|
for i := 0; i < numUsers; i++ {
|
||||||
|
hash, err := sshauth.HashUserID("user" + string(rune(i)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
userHashes[i] = hash
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: userHashes,
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0, 500, 999}, // first, middle, and last user
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// First user should have access
|
||||||
|
err := authorizer.Authorize("user"+string(rune(0)), "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Middle user should have access
|
||||||
|
err = authorizer.Authorize("user"+string(rune(500)), "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Last user should have access
|
||||||
|
err = authorizer.Authorize("user"+string(rune(999)), "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// User not in mapping should NOT have access
|
||||||
|
err = authorizer.Authorize("user"+string(rune(100)), "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up authorized users
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0, 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Test concurrent authorization calls (should be safe to read concurrently)
|
||||||
|
const numGoroutines = 100
|
||||||
|
errChan := make(chan error, numGoroutines)
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func(idx int) {
|
||||||
|
user := "user1"
|
||||||
|
if idx%2 == 0 {
|
||||||
|
user = "user2"
|
||||||
|
}
|
||||||
|
err := authorizer.Authorize(user, "root")
|
||||||
|
errChan <- err
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines to complete and collect errors
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
err := <-errChan
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user3Hash, err := sshauth.HashUserID("user3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure with wildcard - all authorized users can access any OS user
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"*": {0, 1, 2}, // wildcard with all user indexes
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// All authorized users should be able to access any OS user
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "postgres")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user3", "admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user1", "ubuntu")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "nginx")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user3", "docker")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure with wildcard
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"*": {0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 should have access
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Unauthorized user should still be denied even with wildcard
|
||||||
|
err = authorizer.Authorize("unauthorized-user", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure with both wildcard and specific mappings
|
||||||
|
// Wildcard takes precedence for users in the wildcard index list
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"*": {0, 1}, // wildcard for both users
|
||||||
|
"root": {0}, // specific mapping that would normally restrict to user1 only
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Both users should be able to access any other OS user via wildcard
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure WITHOUT wildcard - only specific mappings
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0}, // only user1
|
||||||
|
"postgres": {1}, // only user2
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 can access root
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// user2 can access postgres
|
||||||
|
err = authorizer.Authorize("user2", "postgres")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// user1 cannot access postgres
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
|
// user2 cannot access root
|
||||||
|
err = authorizer.Authorize("user2", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
|
// Neither can access unmapped OS users
|
||||||
|
err = authorizer.Authorize("user1", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
|
||||||
|
// This test covers the scenario where wildcard exists with limited indexes.
|
||||||
|
// Only users whose indexes are in the wildcard list can access any OS user via wildcard.
|
||||||
|
// Other users can only access OS users they are explicitly mapped to.
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Create two authorized user hashes (simulating the base64-encoded hashes in the config)
|
||||||
|
wasmHash, err := sshauth.HashUserID("wasm")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure with wildcard having only index 0, and specific mappings for other OS users
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: "sub",
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"*": {0}, // wildcard with only index 0 - only wasm has wildcard access
|
||||||
|
"alice": {1}, // specific mapping for user2
|
||||||
|
"bob": {1}, // specific mapping for user2
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// wasm (index 0) should access any OS user via wildcard
|
||||||
|
err = authorizer.Authorize("wasm", "root")
|
||||||
|
assert.NoError(t, err, "wasm should access root via wildcard")
|
||||||
|
|
||||||
|
err = authorizer.Authorize("wasm", "alice")
|
||||||
|
assert.NoError(t, err, "wasm should access alice via wildcard")
|
||||||
|
|
||||||
|
err = authorizer.Authorize("wasm", "bob")
|
||||||
|
assert.NoError(t, err, "wasm should access bob via wildcard")
|
||||||
|
|
||||||
|
err = authorizer.Authorize("wasm", "postgres")
|
||||||
|
assert.NoError(t, err, "wasm should access postgres via wildcard")
|
||||||
|
|
||||||
|
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
|
||||||
|
err = authorizer.Authorize("user2", "alice")
|
||||||
|
assert.NoError(t, err, "user2 should access alice via explicit mapping")
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "bob")
|
||||||
|
assert.NoError(t, err, "user2 should access bob via explicit mapping")
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "root")
|
||||||
|
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "postgres")
|
||||||
|
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
// Unauthorized user should still be denied
|
||||||
|
err = authorizer.Authorize("user3", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
||||||
|
}
|
||||||
@@ -27,9 +27,11 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
"github.com/netbirdio/netbird/client/ssh/server"
|
"github.com/netbirdio/netbird/client/ssh/server"
|
||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
@@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) {
|
|||||||
sshServer := server.New(serverConfig)
|
sshServer := server.New(serverConfig)
|
||||||
sshServer.SetAllowRootLogin(true)
|
sshServer.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
// Configure SSH authorization for the test user
|
||||||
|
testUsername := testutil.GetTestUsername(t)
|
||||||
|
testJWTUser := "test-username"
|
||||||
|
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
testUsername: {0}, // Index 0 in AuthorizedUsers
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||||
defer func() { _ = sshServer.Stop() }()
|
defer func() { _ = sshServer.Stop() }()
|
||||||
|
|
||||||
@@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) {
|
|||||||
|
|
||||||
mockDaemon.setHostKey(host, hostPubKey)
|
mockDaemon.setHostKey(host, hostPubKey)
|
||||||
|
|
||||||
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||||
mockDaemon.setJWTToken(validToken)
|
mockDaemon.setJWTToken(validToken)
|
||||||
|
|
||||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
|
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
clientConn, proxyConn := net.Pipe()
|
clientConn, proxyConn := net.Pipe()
|
||||||
@@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
|||||||
return privateKey, jwksJSON
|
return privateKey, jwksJSON
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
claims := jwt.MapClaims{
|
claims := jwt.MapClaims{
|
||||||
"iss": issuer,
|
"iss": issuer,
|
||||||
"aud": audience,
|
"aud": audience,
|
||||||
"sub": "test-user",
|
"sub": user,
|
||||||
"exp": time.Now().Add(time.Hour).Unix(),
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
"iat": time.Now().Unix(),
|
"iat": time.Now().Unix(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,10 +23,12 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
"github.com/netbirdio/netbird/client/ssh/client"
|
"github.com/netbirdio/netbird/client/ssh/client"
|
||||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJWTEnforcement(t *testing.T) {
|
func TestJWTEnforcement(t *testing.T) {
|
||||||
@@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) {
|
|||||||
tc.setupServer(server)
|
tc.setupServer(server)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always set up authorization for test-user to ensure tests fail at JWT validation stage
|
||||||
|
testUserHash, err := sshuserhash.HashUserID("test-user")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Get current OS username for machine user mapping
|
||||||
|
currentUser := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
currentUser: {0}, // Allow test-user (index 0) to access current OS user
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer require.NoError(t, server.Stop())
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
@@ -138,6 +139,8 @@ type Server struct {
|
|||||||
jwtExtractor *jwt.ClaimsExtractor
|
jwtExtractor *jwt.ClaimsExtractor
|
||||||
jwtConfig *JWTConfig
|
jwtConfig *JWTConfig
|
||||||
|
|
||||||
|
authorizer *sshauth.Authorizer
|
||||||
|
|
||||||
suSupportsPty bool
|
suSupportsPty bool
|
||||||
loginIsUtilLinux bool
|
loginIsUtilLinux bool
|
||||||
}
|
}
|
||||||
@@ -179,6 +182,7 @@ func New(config *Config) *Server {
|
|||||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||||
jwtEnabled: config.JWT != nil,
|
jwtEnabled: config.JWT != nil,
|
||||||
jwtConfig: config.JWT,
|
jwtConfig: config.JWT,
|
||||||
|
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
@@ -320,6 +324,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
|||||||
s.wgAddress = addr
|
s.wgAddress = addr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateSSHAuth updates the SSH fine-grained access control configuration
|
||||||
|
// This should be called when network map updates include new SSH auth configuration
|
||||||
|
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Reset JWT validator/extractor to pick up new userIDClaim
|
||||||
|
s.jwtValidator = nil
|
||||||
|
s.jwtExtractor = nil
|
||||||
|
|
||||||
|
s.authorizer.Update(config)
|
||||||
|
}
|
||||||
|
|
||||||
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
||||||
func (s *Server) ensureJWTValidator() error {
|
func (s *Server) ensureJWTValidator() error {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
@@ -328,6 +345,7 @@ func (s *Server) ensureJWTValidator() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
config := s.jwtConfig
|
config := s.jwtConfig
|
||||||
|
authorizer := s.authorizer
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if config == nil {
|
if config == nil {
|
||||||
@@ -343,9 +361,16 @@ func (s *Server) ensureJWTValidator() error {
|
|||||||
true,
|
true,
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor := jwt.NewClaimsExtractor(
|
// Use custom userIDClaim from authorizer if available
|
||||||
|
extractorOptions := []jwt.ClaimsExtractorOption{
|
||||||
jwt.WithAudience(config.Audience),
|
jwt.WithAudience(config.Audience),
|
||||||
)
|
}
|
||||||
|
if authorizer.GetUserIDClaim() != "" {
|
||||||
|
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
|
||||||
|
log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
|
||||||
|
}
|
||||||
|
|
||||||
|
extractor := jwt.NewClaimsExtractor(extractorOptions...)
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -493,29 +518,41 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||||
|
osUsername := ctx.User()
|
||||||
|
remoteAddr := ctx.RemoteAddr()
|
||||||
|
|
||||||
if err := s.ensureJWTValidator(); err != nil {
|
if err := s.ensureJWTValidator(); err != nil {
|
||||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.validateJWTToken(password)
|
token, err := s.validateJWTToken(password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
userAuth, err := s.extractAndValidateUser(token)
|
userAuth, err := s.extractAndValidateUser(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
|
s.mu.RLock()
|
||||||
|
authorizer := s.authorizer
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
|
||||||
|
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := newAuthKey(osUsername, remoteAddr)
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.pendingAuthJWT[key] = userAuth.UserId
|
s.pendingAuthJWT[key] = userAuth.UserId
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -312,6 +312,8 @@ type serviceClient struct {
|
|||||||
daemonVersion string
|
daemonVersion string
|
||||||
updateIndicationLock sync.Mutex
|
updateIndicationLock sync.Mutex
|
||||||
isUpdateIconActive bool
|
isUpdateIconActive bool
|
||||||
|
settingsEnabled bool
|
||||||
|
profilesEnabled bool
|
||||||
showNetworks bool
|
showNetworks bool
|
||||||
wNetworks fyne.Window
|
wNetworks fyne.Window
|
||||||
wProfiles fyne.Window
|
wProfiles fyne.Window
|
||||||
@@ -907,7 +909,7 @@ func (s *serviceClient) updateStatus() error {
|
|||||||
var systrayIconState bool
|
var systrayIconState bool
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case status.Status == string(internal.StatusConnected):
|
case status.Status == string(internal.StatusConnected) && !s.connected:
|
||||||
s.connected = true
|
s.connected = true
|
||||||
s.sendNotification = true
|
s.sendNotification = true
|
||||||
if s.isUpdateIconActive {
|
if s.isUpdateIconActive {
|
||||||
@@ -921,6 +923,7 @@ func (s *serviceClient) updateStatus() error {
|
|||||||
s.mUp.Disable()
|
s.mUp.Disable()
|
||||||
s.mDown.Enable()
|
s.mDown.Enable()
|
||||||
s.mNetworks.Enable()
|
s.mNetworks.Enable()
|
||||||
|
s.mExitNode.Enable()
|
||||||
go s.updateExitNodes()
|
go s.updateExitNodes()
|
||||||
systrayIconState = true
|
systrayIconState = true
|
||||||
case status.Status == string(internal.StatusConnecting):
|
case status.Status == string(internal.StatusConnecting):
|
||||||
@@ -1274,19 +1277,22 @@ func (s *serviceClient) checkAndUpdateFeatures() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.updateIndicationLock.Lock()
|
||||||
|
defer s.updateIndicationLock.Unlock()
|
||||||
|
|
||||||
// Update settings menu based on current features
|
// Update settings menu based on current features
|
||||||
if features != nil && features.DisableUpdateSettings {
|
settingsEnabled := features == nil || !features.DisableUpdateSettings
|
||||||
s.setSettingsEnabled(false)
|
if s.settingsEnabled != settingsEnabled {
|
||||||
} else {
|
s.settingsEnabled = settingsEnabled
|
||||||
s.setSettingsEnabled(true)
|
s.setSettingsEnabled(settingsEnabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update profile menu based on current features
|
// Update profile menu based on current features
|
||||||
if s.mProfile != nil {
|
if s.mProfile != nil {
|
||||||
if features != nil && features.DisableProfiles {
|
profilesEnabled := features == nil || !features.DisableProfiles
|
||||||
s.mProfile.setEnabled(false)
|
if s.profilesEnabled != profilesEnabled {
|
||||||
} else {
|
s.profilesEnabled = profilesEnabled
|
||||||
s.mProfile.setEnabled(true)
|
s.mProfile.setEnabled(profilesEnabled)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ func (s *serviceClient) getWindowsFontFilePath() string {
|
|||||||
"chr-CHER-US": "Gadugi.ttf",
|
"chr-CHER-US": "Gadugi.ttf",
|
||||||
"zh-HK": "Segoeui.ttf",
|
"zh-HK": "Segoeui.ttf",
|
||||||
"zh-TW": "Segoeui.ttf",
|
"zh-TW": "Segoeui.ttf",
|
||||||
"ja-JP": "Yugothm.ttc",
|
|
||||||
"km-KH": "Leelawui.ttf",
|
"km-KH": "Leelawui.ttf",
|
||||||
"ko-KR": "Malgun.ttf",
|
"ko-KR": "Malgun.ttf",
|
||||||
"th-TH": "Leelawui.ttf",
|
"th-TH": "Leelawui.ttf",
|
||||||
|
|||||||
17
go.mod
17
go.mod
@@ -22,7 +22,7 @@ require (
|
|||||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/grpc v1.73.0
|
google.golang.org/grpc v1.75.0
|
||||||
google.golang.org/protobuf v1.36.8
|
google.golang.org/protobuf v1.36.8
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||||
)
|
)
|
||||||
@@ -41,6 +41,7 @@ require (
|
|||||||
github.com/coder/websocket v1.8.13
|
github.com/coder/websocket v1.8.13
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
github.com/creack/pty v1.1.18
|
github.com/creack/pty v1.1.18
|
||||||
|
github.com/dexidp/dex/api/v2 v2.4.0
|
||||||
github.com/eko/gocache/lib/v4 v4.2.0
|
github.com/eko/gocache/lib/v4 v4.2.0
|
||||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
||||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||||
@@ -97,10 +98,10 @@ require (
|
|||||||
github.com/yusufpapurcu/wmi v1.2.4
|
github.com/yusufpapurcu/wmi v1.2.4
|
||||||
github.com/zcalusic/sysinfo v1.1.3
|
github.com/zcalusic/sysinfo v1.1.3
|
||||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0
|
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0
|
||||||
go.opentelemetry.io/otel v1.35.0
|
go.opentelemetry.io/otel v1.37.0
|
||||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
||||||
go.opentelemetry.io/otel/metric v1.35.0
|
go.opentelemetry.io/otel/metric v1.37.0
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.35.0
|
go.opentelemetry.io/otel/sdk/metric v1.37.0
|
||||||
go.uber.org/mock v0.5.0
|
go.uber.org/mock v0.5.0
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
goauthentik.io/api/v3 v3.2023051.3
|
goauthentik.io/api/v3 v3.2023051.3
|
||||||
@@ -124,7 +125,7 @@ require (
|
|||||||
require (
|
require (
|
||||||
cloud.google.com/go/auth v0.3.0 // indirect
|
cloud.google.com/go/auth v0.3.0 // indirect
|
||||||
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
|
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
|
||||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
cloud.google.com/go/compute/metadata v0.7.0 // indirect
|
||||||
dario.cat/mergo v1.0.0 // indirect
|
dario.cat/mergo v1.0.0 // indirect
|
||||||
filippo.io/edwards25519 v1.1.0 // indirect
|
filippo.io/edwards25519 v1.1.0 // indirect
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
||||||
@@ -170,7 +171,7 @@ require (
|
|||||||
github.com/fyne-io/oksvg v0.2.0 // indirect
|
github.com/fyne-io/oksvg v0.2.0 // indirect
|
||||||
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
|
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
|
||||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
||||||
github.com/go-logr/logr v1.4.2 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||||
@@ -248,8 +249,8 @@ require (
|
|||||||
go.opencensus.io v0.24.0 // indirect
|
go.opencensus.io v0.24.0 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||||
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
golang.org/x/image v0.33.0 // indirect
|
golang.org/x/image v0.33.0 // indirect
|
||||||
golang.org/x/text v0.31.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
|
|||||||
40
go.sum
40
go.sum
@@ -4,8 +4,8 @@ cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9
|
|||||||
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
|
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
|
||||||
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
|
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
|
||||||
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||||
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
|
cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU=
|
||||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo=
|
||||||
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
|
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
|
||||||
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
|
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
|
||||||
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
|
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
|
||||||
@@ -117,6 +117,8 @@ github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70J
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A=
|
||||||
|
github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
@@ -164,8 +166,8 @@ github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVin
|
|||||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
|
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
|
||||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
@@ -561,22 +563,22 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.4
|
|||||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
|
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
|
||||||
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
|
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
|
||||||
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
|
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
|
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
|
||||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
|
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
|
||||||
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
|
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
|
||||||
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
|
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||||
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
|
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||||
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
|
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
|
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
|
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
|
||||||
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
|
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||||
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
|
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||||
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
|
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
|
||||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
@@ -761,6 +763,8 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
|
|||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||||
|
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||||
|
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||||
google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk=
|
google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk=
|
||||||
google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw=
|
google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw=
|
||||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||||
@@ -770,8 +774,8 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA
|
|||||||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||||
google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ=
|
google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 h1:FiusG7LWj+4byqhbvmB+Q93B/mOxJLN2DTozDuZm4EU=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:kXqgZtrWaf6qS3jZOCnCH7WYfrvFjkC51bM8fz3RsCA=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
||||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||||
@@ -779,8 +783,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
|
|||||||
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
|
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
|
||||||
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
||||||
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
||||||
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
|
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
|
||||||
google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
|
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||||
|
|||||||
@@ -53,7 +53,8 @@ services:
|
|||||||
command: [
|
command: [
|
||||||
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
|
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
|
||||||
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||||
"--log-file", "console"
|
"--log-file", "console",
|
||||||
|
"--port", "80"
|
||||||
]
|
]
|
||||||
|
|
||||||
# Relay
|
# Relay
|
||||||
|
|||||||
554
infrastructure_files/getting-started-with-dex.sh
Executable file
554
infrastructure_files/getting-started-with-dex.sh
Executable file
@@ -0,0 +1,554 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# NetBird Getting Started with Dex IDP
|
||||||
|
# This script sets up NetBird with Dex as the identity provider
|
||||||
|
|
||||||
|
# Sed pattern to strip base64 padding characters
|
||||||
|
SED_STRIP_PADDING='s/=//g'
|
||||||
|
|
||||||
|
check_docker_compose() {
|
||||||
|
if command -v docker-compose &> /dev/null
|
||||||
|
then
|
||||||
|
echo "docker-compose"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
if docker compose --help &> /dev/null
|
||||||
|
then
|
||||||
|
echo "docker compose"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "docker-compose is not installed or not in PATH. Please follow the steps from the official guide: https://docs.docker.com/engine/install/" > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
check_jq() {
|
||||||
|
if ! command -v jq &> /dev/null
|
||||||
|
then
|
||||||
|
echo "jq is not installed or not in PATH, please install with your package manager. e.g. sudo apt install jq" > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
get_main_ip_address() {
|
||||||
|
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||||
|
interface=$(route -n get default | grep 'interface:' | awk '{print $2}')
|
||||||
|
ip_address=$(ifconfig "$interface" | grep 'inet ' | awk '{print $2}')
|
||||||
|
else
|
||||||
|
interface=$(ip route | grep default | awk '{print $5}' | head -n 1)
|
||||||
|
ip_address=$(ip addr show "$interface" | grep 'inet ' | awk '{print $2}' | cut -d'/' -f1)
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "$ip_address"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
check_nb_domain() {
|
||||||
|
DOMAIN=$1
|
||||||
|
if [[ "$DOMAIN-x" == "-x" ]]; then
|
||||||
|
echo "The NETBIRD_DOMAIN variable cannot be empty." > /dev/stderr
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$DOMAIN" == "netbird.example.com" ]]; then
|
||||||
|
echo "The NETBIRD_DOMAIN cannot be netbird.example.com" > /dev/stderr
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
read_nb_domain() {
|
||||||
|
READ_NETBIRD_DOMAIN=""
|
||||||
|
echo -n "Enter the domain you want to use for NetBird (e.g. netbird.my-domain.com): " > /dev/stderr
|
||||||
|
read -r READ_NETBIRD_DOMAIN < /dev/tty
|
||||||
|
if ! check_nb_domain "$READ_NETBIRD_DOMAIN"; then
|
||||||
|
read_nb_domain
|
||||||
|
fi
|
||||||
|
echo "$READ_NETBIRD_DOMAIN"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
get_turn_external_ip() {
|
||||||
|
TURN_EXTERNAL_IP_CONFIG="#external-ip="
|
||||||
|
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip')
|
||||||
|
if [[ "x-$IP" != "x-" ]]; then
|
||||||
|
TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
|
||||||
|
fi
|
||||||
|
echo "$TURN_EXTERNAL_IP_CONFIG"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
wait_dex() {
|
||||||
|
set +e
|
||||||
|
echo -n "Waiting for Dex to become ready (via $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN)"
|
||||||
|
counter=1
|
||||||
|
while true; do
|
||||||
|
# Check Dex through Caddy proxy (also validates TLS is working)
|
||||||
|
if curl -sk -f -o /dev/null "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/.well-known/openid-configuration" 2>/dev/null; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
if [[ $counter -eq 60 ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "Taking too long. Checking logs..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 caddy
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 dex
|
||||||
|
fi
|
||||||
|
echo -n " ."
|
||||||
|
sleep 2
|
||||||
|
counter=$((counter + 1))
|
||||||
|
done
|
||||||
|
echo " done"
|
||||||
|
set -e
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
init_environment() {
|
||||||
|
CADDY_SECURE_DOMAIN=""
|
||||||
|
NETBIRD_PORT=80
|
||||||
|
NETBIRD_HTTP_PROTOCOL="http"
|
||||||
|
NETBIRD_RELAY_PROTO="rel"
|
||||||
|
TURN_USER="self"
|
||||||
|
TURN_PASSWORD=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
|
||||||
|
NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
|
||||||
|
TURN_MIN_PORT=49152
|
||||||
|
TURN_MAX_PORT=65535
|
||||||
|
TURN_EXTERNAL_IP_CONFIG=$(get_turn_external_ip)
|
||||||
|
|
||||||
|
# Generate secrets for Dex
|
||||||
|
DEX_DASHBOARD_CLIENT_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
|
||||||
|
|
||||||
|
# Generate admin password
|
||||||
|
NETBIRD_ADMIN_PASSWORD=$(openssl rand -base64 16 | sed "$SED_STRIP_PADDING")
|
||||||
|
|
||||||
|
if ! check_nb_domain "$NETBIRD_DOMAIN"; then
|
||||||
|
NETBIRD_DOMAIN=$(read_nb_domain)
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$NETBIRD_DOMAIN" == "use-ip" ]]; then
|
||||||
|
NETBIRD_DOMAIN=$(get_main_ip_address)
|
||||||
|
else
|
||||||
|
NETBIRD_PORT=443
|
||||||
|
CADDY_SECURE_DOMAIN=", $NETBIRD_DOMAIN:$NETBIRD_PORT"
|
||||||
|
NETBIRD_HTTP_PROTOCOL="https"
|
||||||
|
NETBIRD_RELAY_PROTO="rels"
|
||||||
|
fi
|
||||||
|
|
||||||
|
check_jq
|
||||||
|
|
||||||
|
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||||
|
|
||||||
|
if [[ -f dex.yaml ]]; then
|
||||||
|
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
||||||
|
echo "You can use the following commands:"
|
||||||
|
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
||||||
|
echo " rm -f docker-compose.yml Caddyfile dex.yaml dashboard.env turnserver.conf management.json relay.env"
|
||||||
|
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo Rendering initial files...
|
||||||
|
render_docker_compose > docker-compose.yml
|
||||||
|
render_caddyfile > Caddyfile
|
||||||
|
render_dex_config > dex.yaml
|
||||||
|
render_dashboard_env > dashboard.env
|
||||||
|
render_management_json > management.json
|
||||||
|
render_turn_server_conf > turnserver.conf
|
||||||
|
render_relay_env > relay.env
|
||||||
|
|
||||||
|
echo -e "\nStarting Dex IDP\n"
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d caddy dex
|
||||||
|
|
||||||
|
# Wait for Dex to be ready (through caddy proxy)
|
||||||
|
sleep 3
|
||||||
|
wait_dex
|
||||||
|
|
||||||
|
echo -e "\nStarting NetBird services\n"
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d
|
||||||
|
|
||||||
|
echo -e "\nDone!\n"
|
||||||
|
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||||
|
echo ""
|
||||||
|
echo "Login with the following credentials:"
|
||||||
|
echo "Email: admin@$NETBIRD_DOMAIN" | tee .env
|
||||||
|
echo "Password: $NETBIRD_ADMIN_PASSWORD" | tee -a .env
|
||||||
|
echo ""
|
||||||
|
echo "Dex admin UI is not available (Dex has no built-in UI)."
|
||||||
|
echo "To add more users, edit dex.yaml and restart: $DOCKER_COMPOSE_COMMAND restart dex"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
render_caddyfile() {
|
||||||
|
cat <<EOF
|
||||||
|
{
|
||||||
|
debug
|
||||||
|
servers :80,:443 {
|
||||||
|
protocols h1 h2c h2 h3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(security_headers) {
|
||||||
|
header * {
|
||||||
|
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
|
||||||
|
X-Content-Type-Options "nosniff"
|
||||||
|
X-Frame-Options "SAMEORIGIN"
|
||||||
|
X-XSS-Protection "1; mode=block"
|
||||||
|
-Server
|
||||||
|
Referrer-Policy strict-origin-when-cross-origin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
:80${CADDY_SECURE_DOMAIN} {
|
||||||
|
import security_headers
|
||||||
|
# Relay
|
||||||
|
reverse_proxy /relay* relay:80
|
||||||
|
# Signal
|
||||||
|
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
|
||||||
|
# Management
|
||||||
|
reverse_proxy /api/* management:80
|
||||||
|
reverse_proxy /management.ManagementService/* h2c://management:80
|
||||||
|
# Dex
|
||||||
|
reverse_proxy /dex/* dex:5556
|
||||||
|
# Dashboard
|
||||||
|
reverse_proxy /* dashboard:80
|
||||||
|
}
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
render_dex_config() {
|
||||||
|
# Generate bcrypt hash of the admin password
|
||||||
|
# Using a simple approach - htpasswd or python if available
|
||||||
|
ADMIN_PASSWORD_HASH=""
|
||||||
|
if command -v htpasswd &> /dev/null; then
|
||||||
|
ADMIN_PASSWORD_HASH=$(htpasswd -bnBC 10 "" "$NETBIRD_ADMIN_PASSWORD" | tr -d ':\n')
|
||||||
|
elif command -v python3 &> /dev/null; then
|
||||||
|
ADMIN_PASSWORD_HASH=$(python3 -c "import bcrypt; print(bcrypt.hashpw('$NETBIRD_ADMIN_PASSWORD'.encode(), bcrypt.gensalt(rounds=10)).decode())" 2>/dev/null || echo "")
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Fallback to a known hash if we can't generate one
|
||||||
|
if [[ -z "$ADMIN_PASSWORD_HASH" ]]; then
|
||||||
|
# This is hash of "password" - user should change it
|
||||||
|
ADMIN_PASSWORD_HASH='$2a$10$2b2cU8CPhOTaGrs1HRQuAueS7JTT5ZHsHSzYiFPm1leZck7Mc8T4W'
|
||||||
|
NETBIRD_ADMIN_PASSWORD="password"
|
||||||
|
echo "Warning: Could not generate password hash. Using default password: password. Please change it in dex.yaml" > /dev/stderr
|
||||||
|
fi
|
||||||
|
|
||||||
|
cat <<EOF
|
||||||
|
# Dex configuration for NetBird
|
||||||
|
# Generated by getting-started-with-dex.sh
|
||||||
|
|
||||||
|
issuer: $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex
|
||||||
|
|
||||||
|
storage:
|
||||||
|
type: sqlite3
|
||||||
|
config:
|
||||||
|
file: /var/dex/dex.db
|
||||||
|
|
||||||
|
web:
|
||||||
|
http: 0.0.0.0:5556
|
||||||
|
|
||||||
|
# gRPC API for user management (used by NetBird IDP manager)
|
||||||
|
grpc:
|
||||||
|
addr: 0.0.0.0:5557
|
||||||
|
|
||||||
|
oauth2:
|
||||||
|
skipApprovalScreen: true
|
||||||
|
|
||||||
|
# Static OAuth2 clients for NetBird
|
||||||
|
staticClients:
|
||||||
|
# Dashboard client
|
||||||
|
- id: netbird-dashboard
|
||||||
|
name: NetBird Dashboard
|
||||||
|
secret: $DEX_DASHBOARD_CLIENT_SECRET
|
||||||
|
redirectURIs:
|
||||||
|
- $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/nb-auth
|
||||||
|
- $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/nb-silent-auth
|
||||||
|
|
||||||
|
# CLI client (public - uses PKCE)
|
||||||
|
- id: netbird-cli
|
||||||
|
name: NetBird CLI
|
||||||
|
public: true
|
||||||
|
redirectURIs:
|
||||||
|
- http://localhost:53000/
|
||||||
|
- http://localhost:54000/
|
||||||
|
|
||||||
|
# Enable password database for static users
|
||||||
|
enablePasswordDB: true
|
||||||
|
|
||||||
|
# Static users - add more users here as needed
|
||||||
|
staticPasswords:
|
||||||
|
- email: "admin@$NETBIRD_DOMAIN"
|
||||||
|
hash: "$ADMIN_PASSWORD_HASH"
|
||||||
|
username: "admin"
|
||||||
|
userID: "$(uuidgen 2>/dev/null || cat /proc/sys/kernel/random/uuid 2>/dev/null || echo "admin-user-id-001")"
|
||||||
|
|
||||||
|
# Optional: Add external identity provider connectors
|
||||||
|
# connectors:
|
||||||
|
# - type: github
|
||||||
|
# id: github
|
||||||
|
# name: GitHub
|
||||||
|
# config:
|
||||||
|
# clientID: \$GITHUB_CLIENT_ID
|
||||||
|
# clientSecret: \$GITHUB_CLIENT_SECRET
|
||||||
|
# redirectURI: $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/callback
|
||||||
|
#
|
||||||
|
# - type: ldap
|
||||||
|
# id: ldap
|
||||||
|
# name: LDAP
|
||||||
|
# config:
|
||||||
|
# host: ldap.example.com:636
|
||||||
|
# insecureNoSSL: false
|
||||||
|
# bindDN: cn=admin,dc=example,dc=com
|
||||||
|
# bindPW: admin
|
||||||
|
# userSearch:
|
||||||
|
# baseDN: ou=users,dc=example,dc=com
|
||||||
|
# filter: "(objectClass=person)"
|
||||||
|
# username: uid
|
||||||
|
# idAttr: uid
|
||||||
|
# emailAttr: mail
|
||||||
|
# nameAttr: cn
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
render_turn_server_conf() {
|
||||||
|
cat <<EOF
|
||||||
|
listening-port=3478
|
||||||
|
$TURN_EXTERNAL_IP_CONFIG
|
||||||
|
tls-listening-port=5349
|
||||||
|
min-port=$TURN_MIN_PORT
|
||||||
|
max-port=$TURN_MAX_PORT
|
||||||
|
fingerprint
|
||||||
|
lt-cred-mech
|
||||||
|
user=$TURN_USER:$TURN_PASSWORD
|
||||||
|
realm=wiretrustee.com
|
||||||
|
cert=/etc/coturn/certs/cert.pem
|
||||||
|
pkey=/etc/coturn/private/privkey.pem
|
||||||
|
log-file=stdout
|
||||||
|
no-software-attribute
|
||||||
|
pidfile="/var/tmp/turnserver.pid"
|
||||||
|
no-cli
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
render_management_json() {
|
||||||
|
cat <<EOF
|
||||||
|
{
|
||||||
|
"Stuns": [
|
||||||
|
{
|
||||||
|
"Proto": "udp",
|
||||||
|
"URI": "stun:$NETBIRD_DOMAIN:3478"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"Relay": {
|
||||||
|
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
|
||||||
|
"CredentialsTTL": "24h",
|
||||||
|
"Secret": "$NETBIRD_RELAY_AUTH_SECRET"
|
||||||
|
},
|
||||||
|
"Signal": {
|
||||||
|
"Proto": "$NETBIRD_HTTP_PROTOCOL",
|
||||||
|
"URI": "$NETBIRD_DOMAIN:$NETBIRD_PORT"
|
||||||
|
},
|
||||||
|
"HttpConfig": {
|
||||||
|
"AuthIssuer": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex",
|
||||||
|
"AuthAudience": "netbird-dashboard",
|
||||||
|
"OIDCConfigEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/.well-known/openid-configuration"
|
||||||
|
},
|
||||||
|
"IdpManagerConfig": {
|
||||||
|
"ManagerType": "dex",
|
||||||
|
"ClientConfig": {
|
||||||
|
"Issuer": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex"
|
||||||
|
},
|
||||||
|
"ExtraConfig": {
|
||||||
|
"GRPCAddr": "dex:5557"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"DeviceAuthorizationFlow": {
|
||||||
|
"Provider": "hosted",
|
||||||
|
"ProviderConfig": {
|
||||||
|
"Audience": "netbird-cli",
|
||||||
|
"ClientID": "netbird-cli",
|
||||||
|
"Scope": "openid profile email offline_access",
|
||||||
|
"DeviceAuthEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/device/code",
|
||||||
|
"TokenEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/token"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"PKCEAuthorizationFlow": {
|
||||||
|
"ProviderConfig": {
|
||||||
|
"Audience": "netbird-cli",
|
||||||
|
"ClientID": "netbird-cli",
|
||||||
|
"Scope": "openid profile email offline_access",
|
||||||
|
"RedirectURLs": ["http://localhost:53000/", "http://localhost:54000/"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
render_dashboard_env() {
|
||||||
|
cat <<EOF
|
||||||
|
# Endpoints
|
||||||
|
NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN
|
||||||
|
NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN
|
||||||
|
# OIDC
|
||||||
|
AUTH_AUDIENCE=netbird-dashboard
|
||||||
|
AUTH_CLIENT_ID=netbird-dashboard
|
||||||
|
AUTH_CLIENT_SECRET=$DEX_DASHBOARD_CLIENT_SECRET
|
||||||
|
AUTH_AUTHORITY=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex
|
||||||
|
USE_AUTH0=false
|
||||||
|
AUTH_SUPPORTED_SCOPES=openid profile email offline_access
|
||||||
|
AUTH_REDIRECT_URI=/nb-auth
|
||||||
|
AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||||
|
# SSL
|
||||||
|
NGINX_SSL_PORT=443
|
||||||
|
# Letsencrypt
|
||||||
|
LETSENCRYPT_DOMAIN=none
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
render_relay_env() {
|
||||||
|
cat <<EOF
|
||||||
|
NB_LOG_LEVEL=info
|
||||||
|
NB_LISTEN_ADDRESS=:80
|
||||||
|
NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT
|
||||||
|
NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
render_docker_compose() {
|
||||||
|
cat <<EOF
|
||||||
|
services:
|
||||||
|
# Caddy reverse proxy
|
||||||
|
caddy:
|
||||||
|
image: caddy
|
||||||
|
container_name: netbird-caddy
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
ports:
|
||||||
|
- '443:443'
|
||||||
|
- '443:443/udp'
|
||||||
|
- '80:80'
|
||||||
|
volumes:
|
||||||
|
- netbird_caddy_data:/data
|
||||||
|
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Dex - identity provider
|
||||||
|
dex:
|
||||||
|
image: ghcr.io/dexidp/dex:v2.38.0
|
||||||
|
container_name: netbird-dex
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
volumes:
|
||||||
|
- ./dex.yaml:/etc/dex/config.docker.yaml:ro
|
||||||
|
- netbird_dex_data:/var/dex
|
||||||
|
command: ["dex", "serve", "/etc/dex/config.docker.yaml"]
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# UI dashboard
|
||||||
|
dashboard:
|
||||||
|
image: netbirdio/dashboard:latest
|
||||||
|
container_name: netbird-dashboard
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
env_file:
|
||||||
|
- ./dashboard.env
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Signal
|
||||||
|
signal:
|
||||||
|
image: netbirdio/signal:latest
|
||||||
|
container_name: netbird-signal
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Relay
|
||||||
|
relay:
|
||||||
|
image: netbirdio/relay:latest
|
||||||
|
container_name: netbird-relay
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
env_file:
|
||||||
|
- ./relay.env
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Management
|
||||||
|
management:
|
||||||
|
image: netbirdio/management:latest
|
||||||
|
container_name: netbird-management
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
volumes:
|
||||||
|
- netbird_management:/var/lib/netbird
|
||||||
|
- ./management.json:/etc/netbird/management.json
|
||||||
|
command: [
|
||||||
|
"--port", "80",
|
||||||
|
"--log-file", "console",
|
||||||
|
"--log-level", "info",
|
||||||
|
"--disable-anonymous-metrics=false",
|
||||||
|
"--single-account-mode-domain=netbird.selfhosted",
|
||||||
|
"--dns-domain=netbird.selfhosted",
|
||||||
|
"--idp-sign-key-refresh-enabled",
|
||||||
|
]
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Coturn, AKA TURN server
|
||||||
|
coturn:
|
||||||
|
image: coturn/coturn
|
||||||
|
container_name: netbird-coturn
|
||||||
|
restart: unless-stopped
|
||||||
|
volumes:
|
||||||
|
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||||
|
network_mode: host
|
||||||
|
command:
|
||||||
|
- -c /etc/turnserver.conf
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
netbird_caddy_data:
|
||||||
|
netbird_dex_data:
|
||||||
|
netbird_management:
|
||||||
|
|
||||||
|
networks:
|
||||||
|
netbird:
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
init_environment
|
||||||
@@ -178,6 +178,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||||
@@ -224,7 +225,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||||
} else {
|
} else {
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||||
@@ -320,6 +321,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
postureChecks, err := c.getPeerPostureChecks(account, peerId)
|
postureChecks, err := c.getPeerPostureChecks(account, peerId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -338,7 +340,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
if c.experimentalNetworkMap(accountId) {
|
if c.experimentalNetworkMap(accountId) {
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||||
} else {
|
} else {
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
@@ -445,7 +447,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||||
} else {
|
} else {
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
@@ -811,7 +813,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
if c.experimentalNetworkMap(peer.AccountID) {
|
if c.experimentalNetworkMap(peer.AccountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||||
} else {
|
} else {
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
|
|||||||
@@ -158,5 +158,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
@@ -16,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/shared/sshauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||||
@@ -84,15 +88,15 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
|
|||||||
return nbConfig
|
return nbConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig {
|
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig {
|
||||||
netmask, _ := network.Net.Mask.Size()
|
netmask, _ := network.Net.Mask.Size()
|
||||||
fqdn := peer.FQDN(dnsName)
|
fqdn := peer.FQDN(dnsName)
|
||||||
|
|
||||||
sshConfig := &proto.SSHConfig{
|
sshConfig := &proto.SSHConfig{
|
||||||
SshEnabled: peer.SSHEnabled,
|
SshEnabled: peer.SSHEnabled || enableSSH,
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.SSHEnabled {
|
if sshConfig.SshEnabled {
|
||||||
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
|
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,12 +114,12 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
|||||||
|
|
||||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||||
response := &proto.SyncResponse{
|
response := &proto.SyncResponse{
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||||
NetworkMap: &proto.NetworkMap{
|
NetworkMap: &proto.NetworkMap{
|
||||||
Serial: networkMap.Network.CurrentSerial(),
|
Serial: networkMap.Network.CurrentSerial(),
|
||||||
Routes: toProtocolRoutes(networkMap.Routes),
|
Routes: toProtocolRoutes(networkMap.Routes),
|
||||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||||
},
|
},
|
||||||
Checks: toProtocolChecks(ctx, checks),
|
Checks: toProtocolChecks(ctx, checks),
|
||||||
}
|
}
|
||||||
@@ -151,9 +155,45 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
response.NetworkMap.ForwardingRules = forwardingRules
|
response.NetworkMap.ForwardingRules = forwardingRules
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if networkMap.AuthorizedUsers != nil {
|
||||||
|
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||||
|
userIDClaim := auth.DefaultUserIDClaim
|
||||||
|
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||||
|
userIDClaim = httpConfig.AuthUserIDClaim
|
||||||
|
}
|
||||||
|
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||||
|
}
|
||||||
|
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||||
|
userIDToIndex := make(map[string]uint32)
|
||||||
|
var hashedUsers [][]byte
|
||||||
|
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
||||||
|
|
||||||
|
for machineUser, users := range authorizedUsers {
|
||||||
|
indexes := make([]uint32, 0, len(users))
|
||||||
|
for userID := range users {
|
||||||
|
idx, exists := userIDToIndex[userID]
|
||||||
|
if !exists {
|
||||||
|
hash, err := sshauth.HashUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idx = uint32(len(hashedUsers))
|
||||||
|
userIDToIndex[userID] = idx
|
||||||
|
hashedUsers = append(hashedUsers, hash[:])
|
||||||
|
}
|
||||||
|
indexes = append(indexes, idx)
|
||||||
|
}
|
||||||
|
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
||||||
|
}
|
||||||
|
|
||||||
|
return hashedUsers, machineUsers
|
||||||
|
}
|
||||||
|
|
||||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||||
for _, rPeer := range peers {
|
for _, rPeer := range peers {
|
||||||
dst = append(dst, &proto.RemotePeerConfig{
|
dst = append(dst, &proto.RemotePeerConfig{
|
||||||
|
|||||||
@@ -184,8 +184,14 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
realIP := getRealIP(ctx)
|
realIP := getRealIP(ctx)
|
||||||
sRealIP := realIP.String()
|
sRealIP := realIP.String()
|
||||||
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
|
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
|
||||||
|
userID, err := s.accountManager.GetUserIDByPeerKey(ctx, peerKey.String())
|
||||||
|
if err != nil {
|
||||||
|
s.syncSem.Add(-1)
|
||||||
|
return mapError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
metahashed := metaHash(peerMeta, sRealIP)
|
metahashed := metaHash(peerMeta, sRealIP)
|
||||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||||
}
|
}
|
||||||
@@ -270,6 +276,8 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
unlock()
|
unlock()
|
||||||
unlock = nil
|
unlock = nil
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("Sync took %s", time.Since(reqStart))
|
||||||
|
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
|
|
||||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||||
@@ -559,6 +567,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||||
}
|
}
|
||||||
|
log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if loginReq.GetMeta() == nil {
|
if loginReq.GetMeta() == nil {
|
||||||
@@ -635,7 +644,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
|||||||
// if peer has reached this point then it has logged in
|
// if peer has reached this point then it has logged in
|
||||||
loginResp := &proto.LoginResponse{
|
loginResp := &proto.LoginResponse{
|
||||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow),
|
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
|
||||||
Checks: toProtocolChecks(ctx, postureChecks),
|
Checks: toProtocolChecks(ctx, postureChecks),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1456,21 +1456,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.GroupsPropagationEnabled {
|
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
|
||||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
|
if err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
|
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -2158,3 +2156,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) {
|
||||||
|
return am.Store.GetUserIDByPeerKey(ctx, store.LockingStrengthNone, peerKey)
|
||||||
|
}
|
||||||
|
|||||||
@@ -123,4 +123,5 @@ type Manager interface {
|
|||||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
||||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||||
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
|
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||||
|
GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -397,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
|||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthUpdate, accountID, groupID)
|
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
allErrors = errors.Join(allErrors, err)
|
allErrors = errors.Join(allErrors, err)
|
||||||
continue
|
continue
|
||||||
@@ -442,6 +442,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
|||||||
deletedGroups = append(deletedGroups, group)
|
deletedGroups = append(deletedGroups, group)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(groupIDsToDelete) == 0 {
|
||||||
|
return allErrors
|
||||||
|
}
|
||||||
|
|
||||||
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
|
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -299,7 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||||
}
|
}
|
||||||
@@ -369,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
|
|||||||
PortRanges: []types.RulePortRange{portRange},
|
PortRanges: []types.RulePortRange{portRange},
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
|
if protocol == types.PolicyRuleProtocolNetbirdSSH {
|
||||||
|
policy.Rules[0].AuthorizedUser = userAuth.UserId
|
||||||
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
|
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -449,6 +452,18 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
|||||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||||
Ephemeral: peer.Ephemeral,
|
Ephemeral: peer.Ephemeral,
|
||||||
|
LocalFlags: &api.PeerLocalFlags{
|
||||||
|
BlockInbound: &peer.Meta.Flags.BlockInbound,
|
||||||
|
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
|
||||||
|
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
|
||||||
|
DisableDns: &peer.Meta.Flags.DisableDNS,
|
||||||
|
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
|
||||||
|
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
|
||||||
|
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
|
||||||
|
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
|
||||||
|
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
|
||||||
|
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if !approved {
|
if !approved {
|
||||||
@@ -463,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
|||||||
if osVersion == "" {
|
if osVersion == "" {
|
||||||
osVersion = peer.Meta.Core
|
osVersion = peer.Meta.Core
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.PeerBatch{
|
return &api.PeerBatch{
|
||||||
CreatedAt: peer.CreatedAt,
|
CreatedAt: peer.CreatedAt,
|
||||||
Id: peer.ID,
|
Id: peer.ID,
|
||||||
@@ -492,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
|||||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||||
Ephemeral: peer.Ephemeral,
|
Ephemeral: peer.Ephemeral,
|
||||||
|
LocalFlags: &api.PeerLocalFlags{
|
||||||
|
BlockInbound: &peer.Meta.Flags.BlockInbound,
|
||||||
|
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
|
||||||
|
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
|
||||||
|
DisableDns: &peer.Meta.Flags.DisableDNS,
|
||||||
|
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
|
||||||
|
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
|
||||||
|
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
|
||||||
|
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
|
||||||
|
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
|
||||||
|
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
|||||||
pr.Protocol = types.PolicyRuleProtocolUDP
|
pr.Protocol = types.PolicyRuleProtocolUDP
|
||||||
case api.PolicyRuleUpdateProtocolIcmp:
|
case api.PolicyRuleUpdateProtocolIcmp:
|
||||||
pr.Protocol = types.PolicyRuleProtocolICMP
|
pr.Protocol = types.PolicyRuleProtocolICMP
|
||||||
|
case api.PolicyRuleUpdateProtocolNetbirdSsh:
|
||||||
|
pr.Protocol = types.PolicyRuleProtocolNetbirdSSH
|
||||||
default:
|
default:
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
|
||||||
return
|
return
|
||||||
@@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 {
|
||||||
|
for _, sourceGroupID := range pr.Sources {
|
||||||
|
_, ok := (*rule.AuthorizedGroups)[sourceGroupID]
|
||||||
|
if !ok {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pr.AuthorizedGroups = *rule.AuthorizedGroups
|
||||||
|
}
|
||||||
|
|
||||||
// validate policy object
|
// validate policy object
|
||||||
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
|
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
|
||||||
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
|
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
|
||||||
@@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
|
|||||||
DestinationResource: r.DestinationResource.ToAPIResponse(),
|
DestinationResource: r.DestinationResource.ToAPIResponse(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(r.AuthorizedGroups) != 0 {
|
||||||
|
authorizedGroupsCopy := r.AuthorizedGroups
|
||||||
|
rule.AuthorizedGroups = &authorizedGroupsCopy
|
||||||
|
}
|
||||||
|
|
||||||
if len(r.Ports) != 0 {
|
if len(r.Ports) != 0 {
|
||||||
portsCopy := r.Ports
|
portsCopy := r.Ports
|
||||||
rule.Ports = &portsCopy
|
rule.Ports = &portsCopy
|
||||||
|
|||||||
445
management/server/idp/dex.go
Normal file
445
management/server/idp/dex.go
Normal file
@@ -0,0 +1,445 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/api/v2"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DexManager implements the Manager interface for Dex IDP.
|
||||||
|
// It uses Dex's gRPC API to manage users in the password database.
|
||||||
|
type DexManager struct {
|
||||||
|
grpcAddr string
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
helper ManagerHelper
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
mux sync.Mutex
|
||||||
|
conn *grpc.ClientConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// DexClientConfig Dex manager client configuration.
|
||||||
|
type DexClientConfig struct {
|
||||||
|
// GRPCAddr is the address of Dex's gRPC API (e.g., "localhost:5557")
|
||||||
|
GRPCAddr string
|
||||||
|
// Issuer is the Dex issuer URL (e.g., "https://dex.example.com/dex")
|
||||||
|
Issuer string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDexManager creates a new instance of DexManager.
|
||||||
|
func NewDexManager(config DexClientConfig, appMetrics telemetry.AppMetrics) (*DexManager, error) {
|
||||||
|
if config.GRPCAddr == "" {
|
||||||
|
return nil, fmt.Errorf("dex IdP configuration is incomplete, GRPCAddr is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: httpTransport,
|
||||||
|
}
|
||||||
|
helper := JsonParser{}
|
||||||
|
|
||||||
|
return &DexManager{
|
||||||
|
grpcAddr: config.GRPCAddr,
|
||||||
|
httpClient: httpClient,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConnection returns a gRPC connection to Dex, creating one if necessary.
|
||||||
|
// It also checks if an existing connection is still healthy and reconnects if needed.
|
||||||
|
func (dm *DexManager) getConnection(ctx context.Context) (*grpc.ClientConn, error) {
|
||||||
|
dm.mux.Lock()
|
||||||
|
defer dm.mux.Unlock()
|
||||||
|
|
||||||
|
if dm.conn != nil {
|
||||||
|
state := dm.conn.GetState()
|
||||||
|
// If connection is shutdown or in a transient failure, close and reconnect
|
||||||
|
if state == connectivity.Shutdown || state == connectivity.TransientFailure {
|
||||||
|
log.WithContext(ctx).Debugf("Dex gRPC connection in state %s, reconnecting", state)
|
||||||
|
_ = dm.conn.Close()
|
||||||
|
dm.conn = nil
|
||||||
|
} else {
|
||||||
|
return dm.conn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("connecting to Dex gRPC API at %s", dm.grpcAddr)
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(dm.grpcAddr,
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to Dex gRPC API: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dm.conn = conn
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDexClient returns a Dex API client.
|
||||||
|
func (dm *DexManager) getDexClient(ctx context.Context) (api.DexClient, error) {
|
||||||
|
conn, err := dm.getConnection(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return api.NewDexClient(conn), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeDexUserID encodes a user ID and connector ID into Dex's composite format.
|
||||||
|
// This is the reverse of parseDexUserID - it creates the base64-encoded protobuf
|
||||||
|
// format that Dex uses in JWT tokens.
|
||||||
|
func encodeDexUserID(userID, connectorID string) string {
|
||||||
|
// Build simple protobuf structure:
|
||||||
|
// Field 1 (tag 0x0a): user ID string
|
||||||
|
// Field 2 (tag 0x12): connector ID string
|
||||||
|
buf := make([]byte, 0, 2+len(userID)+2+len(connectorID))
|
||||||
|
|
||||||
|
// Field 1: user ID
|
||||||
|
buf = append(buf, 0x0a) // tag for field 1, wire type 2 (length-delimited)
|
||||||
|
buf = append(buf, byte(len(userID))) // length
|
||||||
|
buf = append(buf, []byte(userID)...) // value
|
||||||
|
|
||||||
|
// Field 2: connector ID
|
||||||
|
buf = append(buf, 0x12) // tag for field 2, wire type 2 (length-delimited)
|
||||||
|
buf = append(buf, byte(len(connectorID))) // length
|
||||||
|
buf = append(buf, []byte(connectorID)...) // value
|
||||||
|
|
||||||
|
return base64.StdEncoding.EncodeToString(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDexUserID extracts the actual user ID from Dex's composite user ID.
|
||||||
|
// Dex encodes user IDs in JWT tokens as base64-encoded protobuf with format:
|
||||||
|
// - Field 1 (string): actual user ID
|
||||||
|
// - Field 2 (string): connector ID (e.g., "local")
|
||||||
|
// If the ID is not in this format, it returns the original ID.
|
||||||
|
func parseDexUserID(compositeID string) string {
|
||||||
|
// Try to decode as standard base64
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(compositeID)
|
||||||
|
if err != nil {
|
||||||
|
// Try URL-safe base64
|
||||||
|
decoded, err = base64.RawURLEncoding.DecodeString(compositeID)
|
||||||
|
if err != nil {
|
||||||
|
// Not base64 encoded, return as-is
|
||||||
|
return compositeID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the simple protobuf structure
|
||||||
|
// Field 1 (tag 0x0a): user ID string
|
||||||
|
// Field 2 (tag 0x12): connector ID string
|
||||||
|
if len(decoded) < 2 {
|
||||||
|
return compositeID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for field 1 tag (0x0a = field 1, wire type 2/length-delimited)
|
||||||
|
if decoded[0] != 0x0a {
|
||||||
|
return compositeID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the length of the user ID string
|
||||||
|
length := int(decoded[1])
|
||||||
|
if len(decoded) < 2+length {
|
||||||
|
return compositeID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the user ID
|
||||||
|
userID := string(decoded[2 : 2+length])
|
||||||
|
return userID
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||||
|
// Dex doesn't support app metadata, so this is a no-op.
|
||||||
|
func (dm *DexManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserDataByID requests user data from Dex via user ID.
|
||||||
|
func (dm *DexManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := dm.getDexClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to list passwords from Dex: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse the composite user ID from Dex JWT token
|
||||||
|
actualUserID := parseDexUserID(userID)
|
||||||
|
|
||||||
|
for _, p := range resp.Passwords {
|
||||||
|
// Match against both the raw userID and the parsed actualUserID
|
||||||
|
if p.UserId == userID || p.UserId == actualUserID {
|
||||||
|
return &UserData{
|
||||||
|
Email: p.Email,
|
||||||
|
Name: p.Username,
|
||||||
|
ID: userID, // Return the original ID for consistency
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("user with ID %s not found", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount returns all the users for a given account.
|
||||||
|
// Since Dex doesn't have account concepts, this returns all users.
|
||||||
|
func (dm *DexManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountGetAccount()
|
||||||
|
}
|
||||||
|
|
||||||
|
users, err := dm.getAllUsers(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the account ID for all users
|
||||||
|
for _, user := range users {
|
||||||
|
user.AppMetadata.WTAccountID = accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||||
|
// Since Dex doesn't have account concepts, all users are returned under UnsetAccountID.
|
||||||
|
func (dm *DexManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||||
|
}
|
||||||
|
|
||||||
|
users, err := dm.getAllUsers(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
indexedUsers := make(map[string][]*UserData)
|
||||||
|
indexedUsers[UnsetAccountID] = users
|
||||||
|
|
||||||
|
return indexedUsers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateUser creates a new user in Dex's password database.
|
||||||
|
func (dm *DexManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountCreateUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := dm.getDexClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random password for the new user
|
||||||
|
password := GeneratePassword(16, 2, 2, 2)
|
||||||
|
|
||||||
|
// Hash the password using bcrypt
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a user ID from email (Dex uses email as the key, but we need a stable ID)
|
||||||
|
userID := strings.ReplaceAll(email, "@", "-at-")
|
||||||
|
userID = strings.ReplaceAll(userID, ".", "-")
|
||||||
|
|
||||||
|
req := &api.CreatePasswordReq{
|
||||||
|
Password: &api.Password{
|
||||||
|
Email: email,
|
||||||
|
Username: name,
|
||||||
|
UserId: userID,
|
||||||
|
Hash: hashedPassword,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.CreatePassword(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to create user in Dex: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.AlreadyExists {
|
||||||
|
return nil, fmt.Errorf("user with email %s already exists", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("created user %s in Dex", email)
|
||||||
|
|
||||||
|
return &UserData{
|
||||||
|
Email: email,
|
||||||
|
Name: name,
|
||||||
|
ID: userID,
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: accountID,
|
||||||
|
WTInvitedBy: invitedByEmail,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByEmail searches users with a given email.
|
||||||
|
// If no users have been found, this function returns an empty list.
|
||||||
|
func (dm *DexManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := dm.getDexClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to list passwords from Dex: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]*UserData, 0)
|
||||||
|
for _, p := range resp.Passwords {
|
||||||
|
if strings.EqualFold(p.Email, email) {
|
||||||
|
// Encode the user ID in Dex's composite format to match stored IDs
|
||||||
|
encodedID := encodeDexUserID(p.UserId, "local")
|
||||||
|
users = append(users, &UserData{
|
||||||
|
Email: p.Email,
|
||||||
|
Name: p.Username,
|
||||||
|
ID: encodedID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resends an invitation to a user.
|
||||||
|
// Dex doesn't support invitations, so this returns an error.
|
||||||
|
func (dm *DexManager) InviteUserByID(_ context.Context, _ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented for Dex")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteUser deletes a user from Dex by user ID.
|
||||||
|
func (dm *DexManager) DeleteUser(ctx context.Context, userID string) error {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := dm.getDexClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// First, find the user's email by ID
|
||||||
|
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to list passwords from Dex: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse the composite user ID from Dex JWT token
|
||||||
|
actualUserID := parseDexUserID(userID)
|
||||||
|
|
||||||
|
var email string
|
||||||
|
for _, p := range resp.Passwords {
|
||||||
|
if p.UserId == userID || p.UserId == actualUserID {
|
||||||
|
email = p.Email
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if email == "" {
|
||||||
|
return fmt.Errorf("user with ID %s not found", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the user by email
|
||||||
|
deleteResp, err := client.DeletePassword(ctx, &api.DeletePasswordReq{
|
||||||
|
Email: email,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to delete user from Dex: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if deleteResp.NotFound {
|
||||||
|
return fmt.Errorf("user with email %s not found", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("deleted user %s from Dex", email)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAllUsers retrieves all users from Dex's password database.
|
||||||
|
func (dm *DexManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||||
|
client, err := dm.getDexClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{})
|
||||||
|
if err != nil {
|
||||||
|
if dm.appMetrics != nil {
|
||||||
|
dm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to list passwords from Dex: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]*UserData, 0, len(resp.Passwords))
|
||||||
|
for _, p := range resp.Passwords {
|
||||||
|
// Encode the user ID in Dex's composite format (base64-encoded protobuf)
|
||||||
|
// to match how NetBird stores user IDs from Dex JWT tokens.
|
||||||
|
// The connector ID "local" is used for Dex's password database.
|
||||||
|
encodedID := encodeDexUserID(p.UserId, "local")
|
||||||
|
users = append(users, &UserData{
|
||||||
|
Email: p.Email,
|
||||||
|
Name: p.Username,
|
||||||
|
ID: encodedID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
137
management/server/idp/dex_test.go
Normal file
137
management/server/idp/dex_test.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDexManager(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
inputConfig DexClientConfig
|
||||||
|
assertErrFunc require.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultTestConfig := DexClientConfig{
|
||||||
|
GRPCAddr: "localhost:5557",
|
||||||
|
Issuer: "https://dex.example.com/dex",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase1 := test{
|
||||||
|
name: "Good Configuration",
|
||||||
|
inputConfig: defaultTestConfig,
|
||||||
|
assertErrFunc: require.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase2Config := defaultTestConfig
|
||||||
|
testCase2Config.GRPCAddr = ""
|
||||||
|
|
||||||
|
testCase2 := test{
|
||||||
|
name: "Missing GRPCAddr Configuration",
|
||||||
|
inputConfig: testCase2Config,
|
||||||
|
assertErrFunc: require.Error,
|
||||||
|
assertErrFuncMessage: "should return error when GRPCAddr is empty",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with empty issuer - should still work since issuer is optional for the manager
|
||||||
|
testCase3Config := defaultTestConfig
|
||||||
|
testCase3Config.Issuer = ""
|
||||||
|
|
||||||
|
testCase3 := test{
|
||||||
|
name: "Missing Issuer Configuration - OK",
|
||||||
|
inputConfig: testCase3Config,
|
||||||
|
assertErrFunc: require.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error when issuer is empty",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []test{testCase1, testCase2, testCase3} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
manager, err := NewDexManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
require.NotNil(t, manager, "manager should not be nil")
|
||||||
|
require.Equal(t, testCase.inputConfig.GRPCAddr, manager.grpcAddr, "grpcAddr should match")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDexManagerUpdateUserAppMetadata(t *testing.T) {
|
||||||
|
config := DexClientConfig{
|
||||||
|
GRPCAddr: "localhost:5557",
|
||||||
|
Issuer: "https://dex.example.com/dex",
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := NewDexManager(config, &telemetry.MockAppMetrics{})
|
||||||
|
require.NoError(t, err, "should create manager without error")
|
||||||
|
|
||||||
|
// UpdateUserAppMetadata should be a no-op for Dex
|
||||||
|
err = manager.UpdateUserAppMetadata(context.Background(), "test-user-id", AppMetadata{
|
||||||
|
WTAccountID: "test-account",
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "UpdateUserAppMetadata should not return error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDexManagerInviteUserByID(t *testing.T) {
|
||||||
|
config := DexClientConfig{
|
||||||
|
GRPCAddr: "localhost:5557",
|
||||||
|
Issuer: "https://dex.example.com/dex",
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := NewDexManager(config, &telemetry.MockAppMetrics{})
|
||||||
|
require.NoError(t, err, "should create manager without error")
|
||||||
|
|
||||||
|
// InviteUserByID should return an error for Dex
|
||||||
|
err = manager.InviteUserByID(context.Background(), "test-user-id")
|
||||||
|
require.Error(t, err, "InviteUserByID should return error")
|
||||||
|
require.Contains(t, err.Error(), "not implemented", "error should mention not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDexUserID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
compositeID string
|
||||||
|
expectedID string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Parse base64-encoded protobuf composite ID",
|
||||||
|
// This is a real Dex composite ID: contains user ID "cf5db180-d360-484d-9b78-c5db92146420" and connector "local"
|
||||||
|
compositeID: "CiRjZjVkYjE4MC1kMzYwLTQ4NGQtOWI3OC1jNWRiOTIxNDY0MjASBWxvY2Fs",
|
||||||
|
expectedID: "cf5db180-d360-484d-9b78-c5db92146420",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Return plain ID unchanged",
|
||||||
|
compositeID: "simple-user-id",
|
||||||
|
expectedID: "simple-user-id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Return UUID unchanged",
|
||||||
|
compositeID: "cf5db180-d360-484d-9b78-c5db92146420",
|
||||||
|
expectedID: "cf5db180-d360-484d-9b78-c5db92146420",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Handle empty string",
|
||||||
|
compositeID: "",
|
||||||
|
expectedID: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Handle invalid base64",
|
||||||
|
compositeID: "not-valid-base64!!!",
|
||||||
|
expectedID: "not-valid-base64!!!",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseDexUserID(tt.compositeID)
|
||||||
|
require.Equal(t, tt.expectedID, result, "parsed user ID should match expected")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -173,40 +173,40 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
|||||||
|
|
||||||
return NewZitadelManager(*zitadelClientConfig, appMetrics)
|
return NewZitadelManager(*zitadelClientConfig, appMetrics)
|
||||||
case "authentik":
|
case "authentik":
|
||||||
authentikConfig := AuthentikClientConfig{
|
return NewAuthentikManager(AuthentikClientConfig{
|
||||||
Issuer: config.ClientConfig.Issuer,
|
Issuer: config.ClientConfig.Issuer,
|
||||||
ClientID: config.ClientConfig.ClientID,
|
ClientID: config.ClientConfig.ClientID,
|
||||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||||
GrantType: config.ClientConfig.GrantType,
|
GrantType: config.ClientConfig.GrantType,
|
||||||
Username: config.ExtraConfig["Username"],
|
Username: config.ExtraConfig["Username"],
|
||||||
Password: config.ExtraConfig["Password"],
|
Password: config.ExtraConfig["Password"],
|
||||||
}
|
}, appMetrics)
|
||||||
return NewAuthentikManager(authentikConfig, appMetrics)
|
|
||||||
case "okta":
|
case "okta":
|
||||||
oktaClientConfig := OktaClientConfig{
|
return NewOktaManager(OktaClientConfig{
|
||||||
Issuer: config.ClientConfig.Issuer,
|
Issuer: config.ClientConfig.Issuer,
|
||||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||||
GrantType: config.ClientConfig.GrantType,
|
GrantType: config.ClientConfig.GrantType,
|
||||||
APIToken: config.ExtraConfig["ApiToken"],
|
APIToken: config.ExtraConfig["ApiToken"],
|
||||||
}
|
}, appMetrics)
|
||||||
return NewOktaManager(oktaClientConfig, appMetrics)
|
|
||||||
case "google":
|
case "google":
|
||||||
googleClientConfig := GoogleWorkspaceClientConfig{
|
return NewGoogleWorkspaceManager(ctx, GoogleWorkspaceClientConfig{
|
||||||
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
|
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
|
||||||
CustomerID: config.ExtraConfig["CustomerId"],
|
CustomerID: config.ExtraConfig["CustomerId"],
|
||||||
}
|
}, appMetrics)
|
||||||
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
|
|
||||||
case "jumpcloud":
|
case "jumpcloud":
|
||||||
jumpcloudConfig := JumpCloudClientConfig{
|
return NewJumpCloudManager(JumpCloudClientConfig{
|
||||||
APIToken: config.ExtraConfig["ApiToken"],
|
APIToken: config.ExtraConfig["ApiToken"],
|
||||||
}
|
}, appMetrics)
|
||||||
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
|
|
||||||
case "pocketid":
|
case "pocketid":
|
||||||
pocketidConfig := PocketIdClientConfig{
|
return NewPocketIdManager(PocketIdClientConfig{
|
||||||
APIToken: config.ExtraConfig["ApiToken"],
|
APIToken: config.ExtraConfig["ApiToken"],
|
||||||
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
|
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
|
||||||
}
|
}, appMetrics)
|
||||||
return NewPocketIdManager(pocketidConfig, appMetrics)
|
case "dex":
|
||||||
|
return NewDexManager(DexClientConfig{
|
||||||
|
GRPCAddr: config.ExtraConfig["GRPCAddr"],
|
||||||
|
Issuer: config.ClientConfig.Issuer,
|
||||||
|
}, appMetrics)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,12 @@ package mock_server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
@@ -988,3 +989,7 @@ func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, ac
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) {
|
||||||
|
return "something", nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
|||||||
|
|
||||||
// fetch all the peers that have access to the user's peers
|
// fetch all the peers that have access to the user's peers
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
|
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers())
|
||||||
for _, p := range aclPeers {
|
for _, p := range aclPeers {
|
||||||
peersMap[p.ID] = p
|
peersMap[p.ID] = p
|
||||||
}
|
}
|
||||||
@@ -1057,7 +1057,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range userPeers {
|
for _, p := range userPeers {
|
||||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
|
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers())
|
||||||
for _, aclPeer := range aclPeers {
|
for _, aclPeer := range aclPeers {
|
||||||
if aclPeer.ID == peer.ID {
|
if aclPeer.ID == peer.ID {
|
||||||
return peer, nil
|
return peer, nil
|
||||||
|
|||||||
@@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("check that all peers get map", func(t *testing.T) {
|
t.Run("check that all peers get map", func(t *testing.T) {
|
||||||
for _, p := range account.Peers {
|
for _, p := range account.Peers {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
|
||||||
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
|
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
|
||||||
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
|
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check first peer map details", func(t *testing.T) {
|
t.Run("check first peer map details", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 8)
|
assert.Len(t, peers, 8)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
@@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check port ranges support for older peers", func(t *testing.T) {
|
t.Run("check port ranges support for older peers", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 1)
|
assert.Len(t, peers, 1)
|
||||||
assert.Contains(t, peers, account.Peers["peerI"])
|
assert.Contains(t, peers, account.Peers["peerI"])
|
||||||
|
|
||||||
@@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Run("check first peer map", func(t *testing.T) {
|
t.Run("check first peer map", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
|
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check second peer map", func(t *testing.T) {
|
t.Run("check second peer map", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Contains(t, peers, account.Peers["peerB"])
|
assert.Contains(t, peers, account.Peers["peerB"])
|
||||||
|
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
account.Policies[1].Rules[0].Bidirectional = false
|
account.Policies[1].Rules[0].Bidirectional = false
|
||||||
|
|
||||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
|
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Contains(t, peers, account.Peers["peerB"])
|
assert.Contains(t, peers, account.Peers["peerB"])
|
||||||
|
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||||
// will establish a connection with all source peers satisfying the NB posture check.
|
// will establish a connection with all source peers satisfying the NB posture check.
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 4)
|
assert.Len(t, peers, 4)
|
||||||
assert.Len(t, firewallRules, 4)
|
assert.Len(t, firewallRules, 4)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
@@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||||
// We expect a single permissive firewall rule which all outgoing connections
|
// We expect a single permissive firewall rule which all outgoing connections
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||||
assert.Len(t, firewallRules, 7)
|
assert.Len(t, firewallRules, 7)
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||||
// all source group peers satisfying the NB posture check should establish connection
|
// all source group peers satisfying the NB posture check should establish connection
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 4)
|
assert.Len(t, peers, 4)
|
||||||
assert.Len(t, firewallRules, 4)
|
assert.Len(t, firewallRules, 4)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
@@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||||
// all source group peers satisfying the NB posture check should establish connection
|
// all source group peers satisfying the NB posture check should establish connection
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 4)
|
assert.Len(t, peers, 4)
|
||||||
assert.Len(t, firewallRules, 4)
|
assert.Len(t, firewallRules, 4)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
@@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||||
// no connection should be established to any peer of destination group
|
// no connection should be established to any peer of destination group
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 0)
|
assert.Len(t, peers, 0)
|
||||||
assert.Len(t, firewallRules, 0)
|
assert.Len(t, firewallRules, 0)
|
||||||
|
|
||||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||||
// no connection should be established to any peer of destination group
|
// no connection should be established to any peer of destination group
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 0)
|
assert.Len(t, peers, 0)
|
||||||
assert.Len(t, firewallRules, 0)
|
assert.Len(t, firewallRules, 0)
|
||||||
|
|
||||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||||
// We expect a single permissive firewall rule which all outgoing connections
|
// We expect a single permissive firewall rule which all outgoing connections
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||||
|
|
||||||
@@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||||
// all source group peers satisfying the NB posture check should establish connection
|
// all source group peers satisfying the NB posture check should establish connection
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 3)
|
assert.Len(t, peers, 3)
|
||||||
assert.Len(t, firewallRules, 3)
|
assert.Len(t, firewallRules, 3)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
assert.Contains(t, peers, account.Peers["peerD"])
|
assert.Contains(t, peers, account.Peers["peerD"])
|
||||||
|
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 5)
|
assert.Len(t, peers, 5)
|
||||||
// assert peers from Group Swarm
|
// assert peers from Group Swarm
|
||||||
assert.Contains(t, peers, account.Peers["peerD"])
|
assert.Contains(t, peers, account.Peers["peerD"])
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ type SqlStore struct {
|
|||||||
installationPK int
|
installationPK int
|
||||||
storeEngine types.Engine
|
storeEngine types.Engine
|
||||||
pool *pgxpool.Pool
|
pool *pgxpool.Pool
|
||||||
|
|
||||||
|
transactionTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type installation struct {
|
type installation struct {
|
||||||
@@ -84,6 +86,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
|||||||
conns = runtime.NumCPU()
|
conns = runtime.NumCPU()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
transactionTimeout := 5 * time.Minute
|
||||||
|
if v := os.Getenv("NB_STORE_TRANSACTION_TIMEOUT"); v != "" {
|
||||||
|
if parsed, err := time.ParseDuration(v); err == nil {
|
||||||
|
transactionTimeout = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Infof("Setting transaction timeout to %v", transactionTimeout)
|
||||||
|
|
||||||
if storeEngine == types.SqliteStoreEngine {
|
if storeEngine == types.SqliteStoreEngine {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
|
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
|
||||||
@@ -101,7 +111,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
|||||||
|
|
||||||
if skipMigration {
|
if skipMigration {
|
||||||
log.WithContext(ctx).Infof("skipping migration")
|
log.WithContext(ctx).Infof("skipping migration")
|
||||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1, transactionTimeout: transactionTimeout}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := migratePreAuto(ctx, db); err != nil {
|
if err := migratePreAuto(ctx, db); err != nil {
|
||||||
@@ -120,7 +130,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
|||||||
return nil, fmt.Errorf("migratePostAuto: %w", err)
|
return nil, fmt.Errorf("migratePostAuto: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1, transactionTimeout: transactionTimeout}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetKeyQueryCondition(s *SqlStore) string {
|
func GetKeyQueryCondition(s *SqlStore) string {
|
||||||
@@ -1910,16 +1920,17 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
|
|||||||
if len(policyIDs) == 0 {
|
if len(policyIDs) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
|
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)`
|
||||||
rows, err := s.pool.Query(ctx, query, policyIDs)
|
rows, err := s.pool.Query(ctx, query, policyIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
|
rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
|
||||||
var r types.PolicyRule
|
var r types.PolicyRule
|
||||||
var dest, destRes, sources, sourceRes, ports, portRanges []byte
|
var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte
|
||||||
var enabled, bidirectional sql.NullBool
|
var enabled, bidirectional sql.NullBool
|
||||||
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges)
|
var authorizedUser sql.NullString
|
||||||
|
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &authorizedUser)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if enabled.Valid {
|
if enabled.Valid {
|
||||||
r.Enabled = enabled.Bool
|
r.Enabled = enabled.Bool
|
||||||
@@ -1945,6 +1956,12 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
|
|||||||
if portRanges != nil {
|
if portRanges != nil {
|
||||||
_ = json.Unmarshal(portRanges, &r.PortRanges)
|
_ = json.Unmarshal(portRanges, &r.PortRanges)
|
||||||
}
|
}
|
||||||
|
if authorizedGroups != nil {
|
||||||
|
_ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups)
|
||||||
|
}
|
||||||
|
if authorizedUser.Valid {
|
||||||
|
r.AuthorizedUser = authorizedUser.String
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &r, err
|
return &r, err
|
||||||
})
|
})
|
||||||
@@ -2890,8 +2907,11 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||||
|
timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
tx := s.db.Begin()
|
tx := s.db.WithContext(timeoutCtx).Begin()
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
return tx.Error
|
return tx.Error
|
||||||
}
|
}
|
||||||
@@ -2926,6 +2946,9 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
|
|||||||
err := operation(repo)
|
err := operation(repo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
|
||||||
|
log.WithContext(ctx).Warnf("transaction exceeded %s timeout after %v, stack: %s", s.transactionTimeout, time.Since(startTime), debug.Stack())
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2938,13 +2961,19 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit().Error
|
err = tx.Commit().Error
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
|
||||||
|
log.WithContext(ctx).Warnf("transaction commit exceeded %s timeout after %v, stack: %s", s.transactionTimeout, time.Since(startTime), debug.Stack())
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
|
log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
|
||||||
if s.metrics != nil {
|
if s.metrics != nil {
|
||||||
s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
|
s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||||
@@ -4075,3 +4104,21 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro
|
|||||||
|
|
||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) {
|
||||||
|
tx := s.db
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID string
|
||||||
|
result := tx.Model(&nbpeer.Peer{}).
|
||||||
|
Select("user_id").
|
||||||
|
Take(&userID, GetKeyQueryCondition(s), peerKey)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
return "", status.Errorf(status.Internal, "failed to get user ID by peer key")
|
||||||
|
}
|
||||||
|
|
||||||
|
return userID, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -3718,6 +3718,69 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetUserIDByPeerKey(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
userID := "test-user-123"
|
||||||
|
peerKey := "peer-key-abc"
|
||||||
|
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
ID: "test-peer-1",
|
||||||
|
Key: peerKey,
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
UserID: userID,
|
||||||
|
IP: net.IP{10, 0, 0, 1},
|
||||||
|
DNSLabel: "test-peer-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = store.AddPeerToAccount(context.Background(), peer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
retrievedUserID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, peerKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, userID, retrievedUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetUserIDByPeerKey_NotFound(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
nonExistentPeerKey := "non-existent-peer-key"
|
||||||
|
|
||||||
|
userID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, nonExistentPeerKey)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Equal(t, "", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetUserIDByPeerKey_NoUserID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
peerKey := "peer-key-abc"
|
||||||
|
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
ID: "test-peer-1",
|
||||||
|
Key: peerKey,
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
UserID: "",
|
||||||
|
IP: net.IP{10, 0, 0, 1},
|
||||||
|
DNSLabel: "test-peer-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = store.AddPeerToAccount(context.Background(), peer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
retrievedUserID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, peerKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "", retrievedUserID)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSqlStore_ApproveAccountPeers(t *testing.T) {
|
func TestSqlStore_ApproveAccountPeers(t *testing.T) {
|
||||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||||
accountID := "test-account"
|
accountID := "test-account"
|
||||||
@@ -3794,3 +3857,30 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) {
|
||||||
|
if os.Getenv("NETBIRD_STORE_ENGINE") == "mysql" {
|
||||||
|
t.Skip("Skipping timeout test for MySQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Setenv("NB_STORE_TRANSACTION_TIMEOUT", "1s")
|
||||||
|
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
|
sqlStore, ok := store.(*SqlStore)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, 1*time.Second, sqlStore.transactionTimeout)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err = sqlStore.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
// Sleep for 2 seconds to exceed the 1 second timeout
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// The transaction should fail with an error (either timeout or already rolled back)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
|
||||||
|
}
|
||||||
|
|||||||
@@ -204,6 +204,7 @@ type Store interface {
|
|||||||
MarkAccountPrimary(ctx context.Context, accountID string) error
|
MarkAccountPrimary(ctx context.Context, accountID string) error
|
||||||
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
|
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
|
||||||
GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error)
|
GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error)
|
||||||
|
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
@@ -45,8 +46,10 @@ const (
|
|||||||
|
|
||||||
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
|
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
|
||||||
nativeSSHPortString = "22022"
|
nativeSSHPortString = "22022"
|
||||||
|
nativeSSHPortNumber = 22022
|
||||||
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
|
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
|
||||||
defaultSSHPortString = "22"
|
defaultSSHPortString = "22"
|
||||||
|
defaultSSHPortNumber = 22
|
||||||
)
|
)
|
||||||
|
|
||||||
type supportedFeatures struct {
|
type supportedFeatures struct {
|
||||||
@@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
resourcePolicies map[string][]*Policy,
|
resourcePolicies map[string][]*Policy,
|
||||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||||
metrics *telemetry.AccountManagerMetrics,
|
metrics *telemetry.AccountManagerMetrics,
|
||||||
|
groupIDToUserIDs map[string][]string,
|
||||||
) *NetworkMap {
|
) *NetworkMap {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
peer := a.Peers[peerID]
|
peer := a.Peers[peerID]
|
||||||
@@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
||||||
// exclude expired peers
|
// exclude expired peers
|
||||||
var peersToConnect []*nbpeer.Peer
|
var peersToConnect []*nbpeer.Peer
|
||||||
var expiredPeers []*nbpeer.Peer
|
var expiredPeers []*nbpeer.Peer
|
||||||
@@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
OfflinePeers: expiredPeers,
|
OfflinePeers: expiredPeers,
|
||||||
FirewallRules: firewallRules,
|
FirewallRules: firewallRules,
|
||||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||||
|
AuthorizedUsers: authorizedUsers,
|
||||||
|
EnableSSH: enableSSH,
|
||||||
}
|
}
|
||||||
|
|
||||||
if metrics != nil {
|
if metrics != nil {
|
||||||
@@ -1009,8 +1015,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
|||||||
// GetPeerConnectionResources for a given peer
|
// GetPeerConnectionResources for a given peer
|
||||||
//
|
//
|
||||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
|
||||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||||
|
authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs
|
||||||
|
sshEnabled := false
|
||||||
|
|
||||||
for _, policy := range a.Policies {
|
for _, policy := range a.Policies {
|
||||||
if !policy.Enabled {
|
if !policy.Enabled {
|
||||||
@@ -1053,10 +1061,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
|
|||||||
if peerInDestinations {
|
if peerInDestinations {
|
||||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||||
|
sshEnabled = true
|
||||||
|
switch {
|
||||||
|
case len(rule.AuthorizedGroups) > 0:
|
||||||
|
for groupID, localUsers := range rule.AuthorizedGroups {
|
||||||
|
userIDs, ok := groupIDToUserIDs[groupID]
|
||||||
|
if !ok {
|
||||||
|
log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(localUsers) == 0 {
|
||||||
|
localUsers = []string{auth.Wildcard}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, localUser := range localUsers {
|
||||||
|
if authorizedUsers[localUser] == nil {
|
||||||
|
authorizedUsers[localUser] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
for _, userID := range userIDs {
|
||||||
|
authorizedUsers[localUser][userID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case rule.AuthorizedUser != "":
|
||||||
|
if authorizedUsers[auth.Wildcard] == nil {
|
||||||
|
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
|
||||||
|
default:
|
||||||
|
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||||
|
}
|
||||||
|
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
|
||||||
|
sshEnabled = true
|
||||||
|
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return getAccumulatedResources()
|
peers, fwRules := getAccumulatedResources()
|
||||||
|
return peers, fwRules, authorizedUsers, sshEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) getAllowedUserIDs() map[string]struct{} {
|
||||||
|
users := make(map[string]struct{})
|
||||||
|
for _, nbUser := range a.Users {
|
||||||
|
if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
|
||||||
|
users[nbUser.Id] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return users
|
||||||
}
|
}
|
||||||
|
|
||||||
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||||
@@ -1081,12 +1137,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
|||||||
peersExists[peer.ID] = struct{}{}
|
peersExists[peer.ID] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protocol := rule.Protocol
|
||||||
|
if protocol == PolicyRuleProtocolNetbirdSSH {
|
||||||
|
protocol = PolicyRuleProtocolTCP
|
||||||
|
}
|
||||||
|
|
||||||
fr := FirewallRule{
|
fr := FirewallRule{
|
||||||
PolicyID: rule.ID,
|
PolicyID: rule.ID,
|
||||||
PeerIP: peer.IP.String(),
|
PeerIP: peer.IP.String(),
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
Action: string(rule.Action),
|
Action: string(rule.Action),
|
||||||
Protocol: string(rule.Protocol),
|
Protocol: string(protocol),
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||||
@@ -1108,6 +1169,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||||
|
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
|
||||||
|
for _, pr := range portRanges {
|
||||||
|
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func portsIncludesSSH(ports []string) bool {
|
||||||
|
for _, port := range ports {
|
||||||
|
if port == defaultSSHPortString || port == nativeSSHPortString {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// getAllPeersFromGroups for given peer ID and list of groups
|
// getAllPeersFromGroups for given peer ID and list of groups
|
||||||
//
|
//
|
||||||
// Returns a list of peers from specified groups that pass specified posture checks
|
// Returns a list of peers from specified groups that pass specified posture checks
|
||||||
@@ -1152,7 +1235,11 @@ func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpe
|
|||||||
return []*nbpeer.Peer{}, false
|
return []*nbpeer.Peer{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return []*nbpeer.Peer{peer}, resource.ID == peerID
|
if peer.ID == peerID {
|
||||||
|
return []*nbpeer.Peer{}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return []*nbpeer.Peer{peer}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// validatePostureChecksOnPeer validates the posture checks on a peer
|
// validatePostureChecksOnPeer validates the posture checks on a peer
|
||||||
@@ -1660,6 +1747,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetActiveGroupUsers() map[string][]string {
|
||||||
|
allGroupID := ""
|
||||||
|
group, err := a.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get group all: %v", err)
|
||||||
|
} else {
|
||||||
|
allGroupID = group.ID
|
||||||
|
}
|
||||||
|
groups := make(map[string][]string, len(a.GroupsG))
|
||||||
|
for _, user := range a.Users {
|
||||||
|
if !user.IsBlocked() && !user.IsServiceUser {
|
||||||
|
for _, groupID := range user.AutoGroups {
|
||||||
|
groups[groupID] = append(groups[groupID], user.Id)
|
||||||
|
}
|
||||||
|
groups[allGroupID] = append(groups[allGroupID], user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|
||||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||||
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
||||||
@@ -1691,7 +1798,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer
|
|||||||
expanded = append(expanded, &fr)
|
expanded = append(expanded, &fr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) {
|
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||||
expanded = addNativeSSHRule(base, expanded)
|
expanded = addNativeSSHRule(base, expanded)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_GetActiveGroupUsers(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account *Account
|
||||||
|
expected map[string][]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "all users are active",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{"group1", "group2"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group2", "group3"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user3": {
|
||||||
|
Id: "user3",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"group1": {"user1", "user3"},
|
||||||
|
"group2": {"user1", "user2"},
|
||||||
|
"group3": {"user2"},
|
||||||
|
"": {"user1", "user2", "user3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "some users are blocked",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{"group1", "group2"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group2", "group3"},
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
"user3": {
|
||||||
|
Id: "user3",
|
||||||
|
AutoGroups: []string{"group1", "group3"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"group1": {"user1", "user3"},
|
||||||
|
"group2": {"user1"},
|
||||||
|
"group3": {"user3"},
|
||||||
|
"": {"user1", "user3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all users are blocked",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group2"},
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user with no auto groups",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"group1": {"user2"},
|
||||||
|
"": {"user1", "user2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty account",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple users in same group",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user3": {
|
||||||
|
Id: "user3",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"group1": {"user1", "user2", "user3"},
|
||||||
|
"": {"user1", "user2", "user3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user in multiple groups with blocked users",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{"group1", "group2", "group3"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group1", "group2"},
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
"user3": {
|
||||||
|
Id: "user3",
|
||||||
|
AutoGroups: []string{"group3"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"group1": {"user1"},
|
||||||
|
"group2": {"user1"},
|
||||||
|
"group3": {"user1", "user3"},
|
||||||
|
"": {"user1", "user3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.account.GetActiveGroupUsers()
|
||||||
|
|
||||||
|
// Check that the number of groups matches
|
||||||
|
assert.Equal(t, len(tt.expected), len(result), "number of groups should match")
|
||||||
|
|
||||||
|
// Check each group's users
|
||||||
|
for groupID, expectedUsers := range tt.expected {
|
||||||
|
actualUsers, exists := result[groupID]
|
||||||
|
assert.True(t, exists, "group %s should exist in result", groupID)
|
||||||
|
assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure no extra groups in result
|
||||||
|
for groupID := range result {
|
||||||
|
_, exists := tt.expected[groupID]
|
||||||
|
assert.True(t, exists, "unexpected group %s in result", groupID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -38,6 +38,8 @@ type NetworkMap struct {
|
|||||||
FirewallRules []*FirewallRule
|
FirewallRules []*FirewallRule
|
||||||
RoutesFirewallRules []*RouteFirewallRule
|
RoutesFirewallRules []*RouteFirewallRule
|
||||||
ForwardingRules []*ForwardingRule
|
ForwardingRules []*ForwardingRule
|
||||||
|
AuthorizedUsers map[string]map[string]struct{}
|
||||||
|
EnableSSH bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
@@ -141,7 +141,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
|||||||
b.Run("old builder", func(b *testing.B) {
|
b.Run("old builder", func(b *testing.B) {
|
||||||
for range b.N {
|
for range b.N {
|
||||||
for _, peerID := range peerIDs {
|
for _, peerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -201,7 +201,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
@@ -320,7 +320,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
|||||||
b.Run("old builder after add", func(b *testing.B) {
|
b.Run("old builder after add", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -395,7 +395,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
@@ -550,7 +550,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
|||||||
b.Run("old builder after add", func(b *testing.B) {
|
b.Run("old builder after add", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -604,7 +604,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
@@ -730,7 +730,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
@@ -847,7 +847,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
|||||||
b.Run("old builder after delete", func(b *testing.B) {
|
b.Run("old builder after delete", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ const (
|
|||||||
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
|
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
|
||||||
// PolicyRuleProtocolICMP type of traffic
|
// PolicyRuleProtocolICMP type of traffic
|
||||||
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
|
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
|
||||||
|
// PolicyRuleProtocolNetbirdSSH type of traffic
|
||||||
|
PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error)
|
|||||||
protocol = PolicyRuleProtocolUDP
|
protocol = PolicyRuleProtocolUDP
|
||||||
case "icmp":
|
case "icmp":
|
||||||
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
|
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
|
||||||
|
case "netbird-ssh":
|
||||||
|
return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil
|
||||||
default:
|
default:
|
||||||
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
|
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,6 +80,12 @@ type PolicyRule struct {
|
|||||||
|
|
||||||
// PortRanges a list of port ranges.
|
// PortRanges a list of port ranges.
|
||||||
PortRanges []RulePortRange `gorm:"serializer:json"`
|
PortRanges []RulePortRange `gorm:"serializer:json"`
|
||||||
|
|
||||||
|
// AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh
|
||||||
|
AuthorizedGroups map[string][]string `gorm:"serializer:json"`
|
||||||
|
|
||||||
|
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
|
||||||
|
AuthorizedUser string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy returns a copy of a policy rule
|
// Copy returns a copy of a policy rule
|
||||||
@@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
|||||||
Protocol: pm.Protocol,
|
Protocol: pm.Protocol,
|
||||||
Ports: make([]string, len(pm.Ports)),
|
Ports: make([]string, len(pm.Ports)),
|
||||||
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
||||||
|
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
|
||||||
|
AuthorizedUser: pm.AuthorizedUser,
|
||||||
}
|
}
|
||||||
copy(rule.Destinations, pm.Destinations)
|
copy(rule.Destinations, pm.Destinations)
|
||||||
copy(rule.Sources, pm.Sources)
|
copy(rule.Sources, pm.Sources)
|
||||||
copy(rule.Ports, pm.Ports)
|
copy(rule.Ports, pm.Ports)
|
||||||
copy(rule.PortRanges, pm.PortRanges)
|
copy(rule.PortRanges, pm.PortRanges)
|
||||||
|
for k, v := range pm.AuthorizedGroups {
|
||||||
|
rule.AuthorizedGroups[k] = make([]string, len(v))
|
||||||
|
copy(rule.AuthorizedGroups[k], v)
|
||||||
|
}
|
||||||
return rule
|
return rule
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -523,16 +523,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
_, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
||||||
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
|
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if userHadPeers {
|
updateAccountPeers = true
|
||||||
updateAccountPeers = true
|
|
||||||
}
|
|
||||||
|
|
||||||
err = transaction.SaveUser(ctx, updatedUser)
|
err = transaction.SaveUser(ctx, updatedUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -581,7 +579,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.GroupsPropagationEnabled && updateAccountPeers {
|
if updateAccountPeers {
|
||||||
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
|
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1379,11 +1379,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
|||||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Creating a new regular user should not update account peers and not send peer update
|
// Creating a new regular user should send peer update (as users are not filtered yet)
|
||||||
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldNotReceiveUpdate(t, updMsg)
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -1402,11 +1402,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// updating user with no linked peers should not update account peers and not send peer update
|
// updating user with no linked peers should update account peers and send peer update (as users are not filtered yet)
|
||||||
t.Run("updating user with no linked peers", func(t *testing.T) {
|
t.Run("updating user with no linked peers", func(t *testing.T) {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldNotReceiveUpdate(t, updMsg)
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
216
release_files/freebsd-port-diff.sh
Executable file
216
release_files/freebsd-port-diff.sh
Executable file
@@ -0,0 +1,216 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# FreeBSD Port Diff Generator for NetBird
|
||||||
|
#
|
||||||
|
# This script generates the diff file required for submitting a FreeBSD port update.
|
||||||
|
# It works on macOS, Linux, and FreeBSD by fetching files from FreeBSD cgit and
|
||||||
|
# computing checksums from the Go module proxy.
|
||||||
|
#
|
||||||
|
# Usage: ./freebsd-port-diff.sh [new_version]
|
||||||
|
# Example: ./freebsd-port-diff.sh 0.60.7
|
||||||
|
#
|
||||||
|
# If no version is provided, it fetches the latest from GitHub.
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
GITHUB_REPO="netbirdio/netbird"
|
||||||
|
PORTS_CGIT_BASE="https://cgit.freebsd.org/ports/plain/security/netbird"
|
||||||
|
GO_PROXY="https://proxy.golang.org/github.com/netbirdio/netbird/@v"
|
||||||
|
OUTPUT_DIR="${OUTPUT_DIR:-.}"
|
||||||
|
AWK_FIRST_FIELD='{print $1}'
|
||||||
|
|
||||||
|
fetch_all_tags() {
|
||||||
|
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
||||||
|
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
|
||||||
|
sed 's/.*\/v//' | \
|
||||||
|
sort -u -V
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
fetch_current_ports_version() {
|
||||||
|
echo "Fetching current version from FreeBSD ports..." >&2
|
||||||
|
curl -sL "${PORTS_CGIT_BASE}/Makefile" 2>/dev/null | \
|
||||||
|
grep -E "^DISTVERSION=" | \
|
||||||
|
sed 's/DISTVERSION=[[:space:]]*//' | \
|
||||||
|
tr -d '\t '
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
fetch_latest_github_release() {
|
||||||
|
echo "Fetching latest release from GitHub..." >&2
|
||||||
|
fetch_all_tags | tail -1
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
fetch_ports_file() {
|
||||||
|
local filename="$1"
|
||||||
|
curl -sL "${PORTS_CGIT_BASE}/${filename}" 2>/dev/null
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
compute_checksums() {
|
||||||
|
local version="$1"
|
||||||
|
local tmpdir
|
||||||
|
tmpdir=$(mktemp -d)
|
||||||
|
# shellcheck disable=SC2064
|
||||||
|
trap "rm -rf '$tmpdir'" EXIT
|
||||||
|
|
||||||
|
echo "Downloading files from Go module proxy for v${version}..." >&2
|
||||||
|
|
||||||
|
local mod_file="${tmpdir}/v${version}.mod"
|
||||||
|
local zip_file="${tmpdir}/v${version}.zip"
|
||||||
|
|
||||||
|
curl -sL "${GO_PROXY}/v${version}.mod" -o "$mod_file" 2>/dev/null
|
||||||
|
curl -sL "${GO_PROXY}/v${version}.zip" -o "$zip_file" 2>/dev/null
|
||||||
|
|
||||||
|
if [[ ! -s "$mod_file" ]] || [[ ! -s "$zip_file" ]]; then
|
||||||
|
echo "Error: Could not download files from Go module proxy" >&2
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
local mod_sha256 mod_size zip_sha256 zip_size
|
||||||
|
|
||||||
|
if command -v sha256sum &>/dev/null; then
|
||||||
|
mod_sha256=$(sha256sum "$mod_file" | awk "$AWK_FIRST_FIELD")
|
||||||
|
zip_sha256=$(sha256sum "$zip_file" | awk "$AWK_FIRST_FIELD")
|
||||||
|
elif command -v shasum &>/dev/null; then
|
||||||
|
mod_sha256=$(shasum -a 256 "$mod_file" | awk "$AWK_FIRST_FIELD")
|
||||||
|
zip_sha256=$(shasum -a 256 "$zip_file" | awk "$AWK_FIRST_FIELD")
|
||||||
|
else
|
||||||
|
echo "Error: No sha256 command found" >&2
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||||
|
mod_size=$(stat -f%z "$mod_file")
|
||||||
|
zip_size=$(stat -f%z "$zip_file")
|
||||||
|
else
|
||||||
|
mod_size=$(stat -c%s "$mod_file")
|
||||||
|
zip_size=$(stat -c%s "$zip_file")
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "TIMESTAMP = $(date +%s)"
|
||||||
|
echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_sha256}"
|
||||||
|
echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_size}"
|
||||||
|
echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_sha256}"
|
||||||
|
echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_size}"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
generate_new_makefile() {
|
||||||
|
local new_version="$1"
|
||||||
|
local old_makefile="$2"
|
||||||
|
|
||||||
|
# Check if old version had PORTREVISION
|
||||||
|
if echo "$old_makefile" | grep -q "^PORTREVISION="; then
|
||||||
|
# Remove PORTREVISION line and update DISTVERSION
|
||||||
|
echo "$old_makefile" | \
|
||||||
|
sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/" | \
|
||||||
|
grep -v "^PORTREVISION="
|
||||||
|
else
|
||||||
|
# Just update DISTVERSION
|
||||||
|
echo "$old_makefile" | \
|
||||||
|
sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/"
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse arguments
|
||||||
|
NEW_VERSION="${1:-}"
|
||||||
|
|
||||||
|
# Auto-detect versions if not provided
|
||||||
|
OLD_VERSION=$(fetch_current_ports_version)
|
||||||
|
if [[ -z "$OLD_VERSION" ]]; then
|
||||||
|
echo "Error: Could not fetch current version from FreeBSD ports" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Current FreeBSD ports version: ${OLD_VERSION}" >&2
|
||||||
|
|
||||||
|
if [[ -z "$NEW_VERSION" ]]; then
|
||||||
|
NEW_VERSION=$(fetch_latest_github_release)
|
||||||
|
if [[ -z "$NEW_VERSION" ]]; then
|
||||||
|
echo "Error: Could not fetch latest release from GitHub" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
echo "Target version: ${NEW_VERSION}" >&2
|
||||||
|
|
||||||
|
if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then
|
||||||
|
echo "Port is already at version ${NEW_VERSION}. Nothing to do." >&2
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "" >&2
|
||||||
|
|
||||||
|
# Fetch current files
|
||||||
|
echo "Fetching current Makefile from FreeBSD ports..." >&2
|
||||||
|
OLD_MAKEFILE=$(fetch_ports_file "Makefile")
|
||||||
|
if [[ -z "$OLD_MAKEFILE" ]]; then
|
||||||
|
echo "Error: Could not fetch Makefile" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Fetching current distinfo from FreeBSD ports..." >&2
|
||||||
|
OLD_DISTINFO=$(fetch_ports_file "distinfo")
|
||||||
|
if [[ -z "$OLD_DISTINFO" ]]; then
|
||||||
|
echo "Error: Could not fetch distinfo" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Generate new files
|
||||||
|
echo "Generating new Makefile..." >&2
|
||||||
|
NEW_MAKEFILE=$(generate_new_makefile "$NEW_VERSION" "$OLD_MAKEFILE")
|
||||||
|
|
||||||
|
echo "Computing checksums for new version..." >&2
|
||||||
|
NEW_DISTINFO=$(compute_checksums "$NEW_VERSION")
|
||||||
|
if [[ -z "$NEW_DISTINFO" ]]; then
|
||||||
|
echo "Error: Could not compute checksums" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Create temp files for diff
|
||||||
|
TMPDIR=$(mktemp -d)
|
||||||
|
# shellcheck disable=SC2064
|
||||||
|
trap "rm -rf '$TMPDIR'" EXIT
|
||||||
|
|
||||||
|
mkdir -p "${TMPDIR}/a/security/netbird" "${TMPDIR}/b/security/netbird"
|
||||||
|
|
||||||
|
echo "$OLD_MAKEFILE" > "${TMPDIR}/a/security/netbird/Makefile"
|
||||||
|
echo "$OLD_DISTINFO" > "${TMPDIR}/a/security/netbird/distinfo"
|
||||||
|
echo "$NEW_MAKEFILE" > "${TMPDIR}/b/security/netbird/Makefile"
|
||||||
|
echo "$NEW_DISTINFO" > "${TMPDIR}/b/security/netbird/distinfo"
|
||||||
|
|
||||||
|
# Generate diff
|
||||||
|
OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}.diff"
|
||||||
|
|
||||||
|
echo "" >&2
|
||||||
|
echo "Generating diff..." >&2
|
||||||
|
|
||||||
|
# Generate diff and clean up temp paths to show standard a/b paths
|
||||||
|
(cd "${TMPDIR}" && diff -ruN "a/security/netbird" "b/security/netbird") > "$OUTPUT_FILE" || true
|
||||||
|
|
||||||
|
if [[ ! -s "$OUTPUT_FILE" ]]; then
|
||||||
|
echo "Error: Generated diff is empty" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "" >&2
|
||||||
|
echo "========================================="
|
||||||
|
echo "Diff saved to: ${OUTPUT_FILE}"
|
||||||
|
echo "========================================="
|
||||||
|
echo ""
|
||||||
|
cat "$OUTPUT_FILE"
|
||||||
|
echo ""
|
||||||
|
echo "========================================="
|
||||||
|
echo ""
|
||||||
|
echo "Next steps:"
|
||||||
|
echo "1. Review the diff above"
|
||||||
|
echo "2. Submit to https://bugs.freebsd.org/bugzilla/"
|
||||||
|
echo "3. Use ./freebsd-port-issue-body.sh to generate the issue content"
|
||||||
|
echo ""
|
||||||
|
echo "For FreeBSD testing (optional but recommended):"
|
||||||
|
echo " cd /usr/ports/security/netbird"
|
||||||
|
echo " patch < ${OUTPUT_FILE}"
|
||||||
|
echo " make stage && make stage-qa && make package && make install"
|
||||||
|
echo " netbird status"
|
||||||
|
echo " make deinstall"
|
||||||
159
release_files/freebsd-port-issue-body.sh
Executable file
159
release_files/freebsd-port-issue-body.sh
Executable file
@@ -0,0 +1,159 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# FreeBSD Port Issue Body Generator for NetBird
|
||||||
|
#
|
||||||
|
# This script generates the issue body content for submitting a FreeBSD port update
|
||||||
|
# to the FreeBSD Bugzilla at https://bugs.freebsd.org/bugzilla/
|
||||||
|
#
|
||||||
|
# Usage: ./freebsd-port-issue-body.sh [old_version] [new_version]
|
||||||
|
# Example: ./freebsd-port-issue-body.sh 0.56.0 0.59.1
|
||||||
|
#
|
||||||
|
# If no versions are provided, the script will:
|
||||||
|
# - Fetch OLD version from FreeBSD ports cgit (current version in ports tree)
|
||||||
|
# - Fetch NEW version from latest NetBird GitHub release tag
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
GITHUB_REPO="netbirdio/netbird"
|
||||||
|
PORTS_CGIT_URL="https://cgit.freebsd.org/ports/plain/security/netbird/Makefile"
|
||||||
|
|
||||||
|
fetch_current_ports_version() {
|
||||||
|
echo "Fetching current version from FreeBSD ports..." >&2
|
||||||
|
local makefile_content
|
||||||
|
makefile_content=$(curl -sL "$PORTS_CGIT_URL" 2>/dev/null)
|
||||||
|
if [[ -z "$makefile_content" ]]; then
|
||||||
|
echo "Error: Could not fetch Makefile from FreeBSD ports" >&2
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
echo "$makefile_content" | grep -E "^DISTVERSION=" | sed 's/DISTVERSION=[[:space:]]*//' | tr -d '\t '
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
fetch_all_tags() {
|
||||||
|
# Fetch tags from GitHub tags page (no rate limiting, no auth needed)
|
||||||
|
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
||||||
|
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
|
||||||
|
sed 's/.*\/v//' | \
|
||||||
|
sort -u -V
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
fetch_latest_github_release() {
|
||||||
|
echo "Fetching latest release from GitHub..." >&2
|
||||||
|
local latest
|
||||||
|
|
||||||
|
# Fetch from GitHub tags page
|
||||||
|
latest=$(fetch_all_tags | tail -1)
|
||||||
|
|
||||||
|
if [[ -z "$latest" ]]; then
|
||||||
|
# Fallback to GitHub API
|
||||||
|
latest=$(curl -sL "https://api.github.com/repos/${GITHUB_REPO}/releases/latest" 2>/dev/null | \
|
||||||
|
grep '"tag_name"' | sed 's/.*"tag_name": *"v\([^"]*\)".*/\1/')
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -z "$latest" ]]; then
|
||||||
|
echo "Error: Could not fetch latest release from GitHub" >&2
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
echo "$latest"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
OLD_VERSION="${1:-}"
|
||||||
|
NEW_VERSION="${2:-}"
|
||||||
|
|
||||||
|
if [[ -z "$OLD_VERSION" ]]; then
|
||||||
|
OLD_VERSION=$(fetch_current_ports_version)
|
||||||
|
if [[ -z "$OLD_VERSION" ]]; then
|
||||||
|
echo "Error: Could not determine old version. Please provide it manually." >&2
|
||||||
|
echo "Usage: $0 <old_version> <new_version>" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Detected OLD version from FreeBSD ports: $OLD_VERSION" >&2
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -z "$NEW_VERSION" ]]; then
|
||||||
|
NEW_VERSION=$(fetch_latest_github_release)
|
||||||
|
if [[ -z "$NEW_VERSION" ]]; then
|
||||||
|
echo "Error: Could not determine new version. Please provide it manually." >&2
|
||||||
|
echo "Usage: $0 <old_version> <new_version>" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Detected NEW version from GitHub: $NEW_VERSION" >&2
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then
|
||||||
|
echo "Warning: OLD and NEW versions are the same ($OLD_VERSION). Port may already be up to date." >&2
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "" >&2
|
||||||
|
|
||||||
|
OUTPUT_DIR="${OUTPUT_DIR:-.}"
|
||||||
|
|
||||||
|
fetch_releases_between_versions() {
|
||||||
|
echo "Fetching release history from GitHub..." >&2
|
||||||
|
|
||||||
|
# Fetch all tags and filter to those between OLD and NEW versions
|
||||||
|
fetch_all_tags | \
|
||||||
|
while read -r ver; do
|
||||||
|
if [[ "$(printf '%s\n' "$OLD_VERSION" "$ver" | sort -V | head -n1)" = "$OLD_VERSION" ]] && \
|
||||||
|
[[ "$(printf '%s\n' "$ver" "$NEW_VERSION" | sort -V | head -n1)" = "$ver" ]] && \
|
||||||
|
[[ "$ver" != "$OLD_VERSION" ]]; then
|
||||||
|
echo "$ver"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
generate_changelog_section() {
|
||||||
|
local releases
|
||||||
|
releases=$(fetch_releases_between_versions)
|
||||||
|
|
||||||
|
echo "Changelogs:"
|
||||||
|
if [[ -n "$releases" ]]; then
|
||||||
|
echo "$releases" | while read -r ver; do
|
||||||
|
echo "https://github.com/${GITHUB_REPO}/releases/tag/v${ver}"
|
||||||
|
done
|
||||||
|
else
|
||||||
|
echo "https://github.com/${GITHUB_REPO}/releases/tag/v${NEW_VERSION}"
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}-issue.txt"
|
||||||
|
|
||||||
|
cat << EOF > "$OUTPUT_FILE"
|
||||||
|
BUGZILLA ISSUE DETAILS
|
||||||
|
======================
|
||||||
|
|
||||||
|
Severity: Affects Some People
|
||||||
|
|
||||||
|
Summary: security/netbird: Update to ${NEW_VERSION}
|
||||||
|
|
||||||
|
Description:
|
||||||
|
------------
|
||||||
|
security/netbird: Update ${OLD_VERSION} => ${NEW_VERSION}
|
||||||
|
|
||||||
|
$(generate_changelog_section)
|
||||||
|
|
||||||
|
Commit log:
|
||||||
|
https://github.com/${GITHUB_REPO}/compare/v${OLD_VERSION}...v${NEW_VERSION}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
echo "========================================="
|
||||||
|
echo "Issue body saved to: ${OUTPUT_FILE}"
|
||||||
|
echo "========================================="
|
||||||
|
echo ""
|
||||||
|
cat "$OUTPUT_FILE"
|
||||||
|
echo ""
|
||||||
|
echo "========================================="
|
||||||
|
echo ""
|
||||||
|
echo "Next steps:"
|
||||||
|
echo "1. Go to https://bugs.freebsd.org/bugzilla/ and login"
|
||||||
|
echo "2. Click 'Report an update or defect to a port'"
|
||||||
|
echo "3. Fill in:"
|
||||||
|
echo " - Severity: Affects Some People"
|
||||||
|
echo " - Summary: security/netbird: Update to ${NEW_VERSION}"
|
||||||
|
echo " - Description: Copy content from ${OUTPUT_FILE}"
|
||||||
|
echo "4. Attach diff file: netbird-${NEW_VERSION}.diff"
|
||||||
|
echo "5. Submit the bug report"
|
||||||
@@ -111,6 +111,8 @@ func (c *GrpcClient) ready() bool {
|
|||||||
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
||||||
// Blocking request. The result will be sent via msgHandler callback function
|
// Blocking request. The result will be sent via msgHandler callback function
|
||||||
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
|
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
|
backOff := defaultBackoff(ctx)
|
||||||
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
log.Debugf("management connection state %v", c.conn.GetState())
|
log.Debugf("management connection state %v", c.conn.GetState())
|
||||||
connState := c.conn.GetState()
|
connState := c.conn.GetState()
|
||||||
@@ -128,10 +130,10 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
|
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler, backOff)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := backoff.Retry(operation, defaultBackoff(ctx))
|
err := backoff.Retry(operation, backOff)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
|
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
|
||||||
}
|
}
|
||||||
@@ -140,7 +142,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
|
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
|
||||||
msgHandler func(msg *proto.SyncResponse) error) error {
|
msgHandler func(msg *proto.SyncResponse) error, backOff backoff.BackOff) error {
|
||||||
ctx, cancelStream := context.WithCancel(ctx)
|
ctx, cancelStream := context.WithCancel(ctx)
|
||||||
defer cancelStream()
|
defer cancelStream()
|
||||||
|
|
||||||
@@ -158,6 +160,9 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key,
|
|||||||
|
|
||||||
// blocking until error
|
// blocking until error
|
||||||
err = c.receiveEvents(stream, serverPubKey, msgHandler)
|
err = c.receiveEvents(stream, serverPubKey, msgHandler)
|
||||||
|
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
|
||||||
|
// reset times and next try will start with a long delay
|
||||||
|
backOff.Reset()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.notifyDisconnected(err)
|
c.notifyDisconnected(err)
|
||||||
s, _ := gstatus.FromError(err)
|
s, _ := gstatus.FromError(err)
|
||||||
|
|||||||
@@ -488,6 +488,8 @@ components:
|
|||||||
description: Indicates whether the peer is ephemeral or not
|
description: Indicates whether the peer is ephemeral or not
|
||||||
type: boolean
|
type: boolean
|
||||||
example: false
|
example: false
|
||||||
|
local_flags:
|
||||||
|
$ref: '#/components/schemas/PeerLocalFlags'
|
||||||
required:
|
required:
|
||||||
- city_name
|
- city_name
|
||||||
- connected
|
- connected
|
||||||
@@ -514,6 +516,49 @@ components:
|
|||||||
- serial_number
|
- serial_number
|
||||||
- extra_dns_labels
|
- extra_dns_labels
|
||||||
- ephemeral
|
- ephemeral
|
||||||
|
PeerLocalFlags:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
rosenpass_enabled:
|
||||||
|
description: Indicates whether Rosenpass is enabled on this peer
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
|
rosenpass_permissive:
|
||||||
|
description: Indicates whether Rosenpass is in permissive mode or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
server_ssh_allowed:
|
||||||
|
description: Indicates whether SSH access this peer is allowed or not
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
|
disable_client_routes:
|
||||||
|
description: Indicates whether client routes are disabled on this peer or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
disable_server_routes:
|
||||||
|
description: Indicates whether server routes are disabled on this peer or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
disable_dns:
|
||||||
|
description: Indicates whether DNS management is disabled on this peer or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
disable_firewall:
|
||||||
|
description: Indicates whether firewall management is disabled on this peer or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
block_lan_access:
|
||||||
|
description: Indicates whether LAN access is blocked on this peer when used as a routing peer
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
block_inbound:
|
||||||
|
description: Indicates whether inbound traffic is blocked on this peer
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
lazy_connection_enabled:
|
||||||
|
description: Indicates whether lazy connection is enabled on this peer
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
PeerTemporaryAccessRequest:
|
PeerTemporaryAccessRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -936,7 +981,7 @@ components:
|
|||||||
protocol:
|
protocol:
|
||||||
description: Policy rule type of the traffic
|
description: Policy rule type of the traffic
|
||||||
type: string
|
type: string
|
||||||
enum: ["all", "tcp", "udp", "icmp"]
|
enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"]
|
||||||
example: "tcp"
|
example: "tcp"
|
||||||
ports:
|
ports:
|
||||||
description: Policy rule affected ports
|
description: Policy rule affected ports
|
||||||
@@ -949,6 +994,14 @@ components:
|
|||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: '#/components/schemas/RulePortRange'
|
$ref: '#/components/schemas/RulePortRange'
|
||||||
|
authorized_groups:
|
||||||
|
description: Map of user group ids to a list of local users
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
example: "group1"
|
||||||
required:
|
required:
|
||||||
- name
|
- name
|
||||||
- enabled
|
- enabled
|
||||||
|
|||||||
@@ -130,10 +130,11 @@ const (
|
|||||||
|
|
||||||
// Defines values for PolicyRuleProtocol.
|
// Defines values for PolicyRuleProtocol.
|
||||||
const (
|
const (
|
||||||
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
|
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
|
||||||
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
|
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
|
||||||
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
|
PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh"
|
||||||
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
|
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
|
||||||
|
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for PolicyRuleMinimumAction.
|
// Defines values for PolicyRuleMinimumAction.
|
||||||
@@ -144,10 +145,11 @@ const (
|
|||||||
|
|
||||||
// Defines values for PolicyRuleMinimumProtocol.
|
// Defines values for PolicyRuleMinimumProtocol.
|
||||||
const (
|
const (
|
||||||
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
|
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
|
||||||
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
|
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
|
||||||
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
|
PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh"
|
||||||
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
|
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
|
||||||
|
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for PolicyRuleUpdateAction.
|
// Defines values for PolicyRuleUpdateAction.
|
||||||
@@ -158,10 +160,11 @@ const (
|
|||||||
|
|
||||||
// Defines values for PolicyRuleUpdateProtocol.
|
// Defines values for PolicyRuleUpdateProtocol.
|
||||||
const (
|
const (
|
||||||
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
|
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
|
||||||
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
|
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
|
||||||
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
|
PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh"
|
||||||
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
|
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
|
||||||
|
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for ResourceType.
|
// Defines values for ResourceType.
|
||||||
@@ -1077,7 +1080,8 @@ type Peer struct {
|
|||||||
LastLogin time.Time `json:"last_login"`
|
LastLogin time.Time `json:"last_login"`
|
||||||
|
|
||||||
// LastSeen Last time peer connected to Netbird's management service
|
// LastSeen Last time peer connected to Netbird's management service
|
||||||
LastSeen time.Time `json:"last_seen"`
|
LastSeen time.Time `json:"last_seen"`
|
||||||
|
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
|
||||||
|
|
||||||
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
||||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||||
@@ -1167,7 +1171,8 @@ type PeerBatch struct {
|
|||||||
LastLogin time.Time `json:"last_login"`
|
LastLogin time.Time `json:"last_login"`
|
||||||
|
|
||||||
// LastSeen Last time peer connected to Netbird's management service
|
// LastSeen Last time peer connected to Netbird's management service
|
||||||
LastSeen time.Time `json:"last_seen"`
|
LastSeen time.Time `json:"last_seen"`
|
||||||
|
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
|
||||||
|
|
||||||
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
||||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||||
@@ -1197,6 +1202,39 @@ type PeerBatch struct {
|
|||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerLocalFlags defines model for PeerLocalFlags.
|
||||||
|
type PeerLocalFlags struct {
|
||||||
|
// BlockInbound Indicates whether inbound traffic is blocked on this peer
|
||||||
|
BlockInbound *bool `json:"block_inbound,omitempty"`
|
||||||
|
|
||||||
|
// BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer
|
||||||
|
BlockLanAccess *bool `json:"block_lan_access,omitempty"`
|
||||||
|
|
||||||
|
// DisableClientRoutes Indicates whether client routes are disabled on this peer or not
|
||||||
|
DisableClientRoutes *bool `json:"disable_client_routes,omitempty"`
|
||||||
|
|
||||||
|
// DisableDns Indicates whether DNS management is disabled on this peer or not
|
||||||
|
DisableDns *bool `json:"disable_dns,omitempty"`
|
||||||
|
|
||||||
|
// DisableFirewall Indicates whether firewall management is disabled on this peer or not
|
||||||
|
DisableFirewall *bool `json:"disable_firewall,omitempty"`
|
||||||
|
|
||||||
|
// DisableServerRoutes Indicates whether server routes are disabled on this peer or not
|
||||||
|
DisableServerRoutes *bool `json:"disable_server_routes,omitempty"`
|
||||||
|
|
||||||
|
// LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer
|
||||||
|
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
|
||||||
|
|
||||||
|
// RosenpassEnabled Indicates whether Rosenpass is enabled on this peer
|
||||||
|
RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"`
|
||||||
|
|
||||||
|
// RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not
|
||||||
|
RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"`
|
||||||
|
|
||||||
|
// ServerSshAllowed Indicates whether SSH access this peer is allowed or not
|
||||||
|
ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// PeerMinimum defines model for PeerMinimum.
|
// PeerMinimum defines model for PeerMinimum.
|
||||||
type PeerMinimum struct {
|
type PeerMinimum struct {
|
||||||
// Id Peer ID
|
// Id Peer ID
|
||||||
@@ -1349,6 +1387,9 @@ type PolicyRule struct {
|
|||||||
// Action Policy rule accept or drops packets
|
// Action Policy rule accept or drops packets
|
||||||
Action PolicyRuleAction `json:"action"`
|
Action PolicyRuleAction `json:"action"`
|
||||||
|
|
||||||
|
// AuthorizedGroups Map of user group ids to a list of local users
|
||||||
|
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||||
|
|
||||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||||
Bidirectional bool `json:"bidirectional"`
|
Bidirectional bool `json:"bidirectional"`
|
||||||
|
|
||||||
@@ -1393,6 +1434,9 @@ type PolicyRuleMinimum struct {
|
|||||||
// Action Policy rule accept or drops packets
|
// Action Policy rule accept or drops packets
|
||||||
Action PolicyRuleMinimumAction `json:"action"`
|
Action PolicyRuleMinimumAction `json:"action"`
|
||||||
|
|
||||||
|
// AuthorizedGroups Map of user group ids to a list of local users
|
||||||
|
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||||
|
|
||||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||||
Bidirectional bool `json:"bidirectional"`
|
Bidirectional bool `json:"bidirectional"`
|
||||||
|
|
||||||
@@ -1426,6 +1470,9 @@ type PolicyRuleUpdate struct {
|
|||||||
// Action Policy rule accept or drops packets
|
// Action Policy rule accept or drops packets
|
||||||
Action PolicyRuleUpdateAction `json:"action"`
|
Action PolicyRuleUpdateAction `json:"action"`
|
||||||
|
|
||||||
|
// AuthorizedGroups Map of user group ids to a list of local users
|
||||||
|
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||||
|
|
||||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||||
Bidirectional bool `json:"bidirectional"`
|
Bidirectional bool `json:"bidirectional"`
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -332,6 +332,24 @@ message NetworkMap {
|
|||||||
bool routesFirewallRulesIsEmpty = 11;
|
bool routesFirewallRulesIsEmpty = 11;
|
||||||
|
|
||||||
repeated ForwardingRule forwardingRules = 12;
|
repeated ForwardingRule forwardingRules = 12;
|
||||||
|
|
||||||
|
// SSHAuth represents SSH authorization configuration
|
||||||
|
SSHAuth sshAuth = 13;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SSHAuth {
|
||||||
|
// UserIDClaim is the JWT claim to be used to get the users ID
|
||||||
|
string UserIDClaim = 1;
|
||||||
|
|
||||||
|
// AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH
|
||||||
|
repeated bytes AuthorizedUsers = 2;
|
||||||
|
|
||||||
|
// MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list
|
||||||
|
map<string, MachineUserIndexes> machine_users = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message MachineUserIndexes {
|
||||||
|
repeated uint32 indexes = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemotePeerConfig represents a configuration of a remote peer.
|
// RemotePeerConfig represents a configuration of a remote peer.
|
||||||
|
|||||||
28
shared/sshauth/userhash.go
Normal file
28
shared/sshauth/userhash.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package sshauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/blake2b"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserIDHash represents a hashed user ID (BLAKE2b-128)
|
||||||
|
type UserIDHash [16]byte
|
||||||
|
|
||||||
|
// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value
|
||||||
|
// This function must produce the same hash on both client and management server
|
||||||
|
func HashUserID(userID string) (UserIDHash, error) {
|
||||||
|
hash, err := blake2b.New(16, nil)
|
||||||
|
if err != nil {
|
||||||
|
return UserIDHash{}, err
|
||||||
|
}
|
||||||
|
hash.Write([]byte(userID))
|
||||||
|
var result UserIDHash
|
||||||
|
copy(result[:], hash.Sum(nil))
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the hexadecimal string representation of the hash
|
||||||
|
func (h UserIDHash) String() string {
|
||||||
|
return hex.EncodeToString(h[:])
|
||||||
|
}
|
||||||
210
shared/sshauth/userhash_test.go
Normal file
210
shared/sshauth/userhash_test.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
package sshauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHashUserID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple user ID",
|
||||||
|
userID: "user@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UUID format",
|
||||||
|
userID: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "numeric ID",
|
||||||
|
userID: "12345",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
userID: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special characters",
|
||||||
|
userID: "user+test@domain.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode characters",
|
||||||
|
userID: "用户@example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hash, err := HashUserID(tt.userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("HashUserID() error = %v, want nil", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify hash is non-zero for non-empty inputs
|
||||||
|
if tt.userID != "" && hash == [16]byte{} {
|
||||||
|
t.Errorf("HashUserID() returned zero hash for non-empty input")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashUserID_Consistency(t *testing.T) {
|
||||||
|
userID := "test@example.com"
|
||||||
|
|
||||||
|
hash1, err1 := HashUserID(userID)
|
||||||
|
if err1 != nil {
|
||||||
|
t.Fatalf("First HashUserID() error = %v", err1)
|
||||||
|
}
|
||||||
|
|
||||||
|
hash2, err2 := HashUserID(userID)
|
||||||
|
if err2 != nil {
|
||||||
|
t.Fatalf("Second HashUserID() error = %v", err2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hash1 != hash2 {
|
||||||
|
t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashUserID_Uniqueness(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
userID1 string
|
||||||
|
userID2 string
|
||||||
|
}{
|
||||||
|
{"user1@example.com", "user2@example.com"},
|
||||||
|
{"alice@domain.com", "bob@domain.com"},
|
||||||
|
{"test", "test1"},
|
||||||
|
{"", "a"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
hash1, err1 := HashUserID(tt.userID1)
|
||||||
|
if err1 != nil {
|
||||||
|
t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1)
|
||||||
|
}
|
||||||
|
|
||||||
|
hash2, err2 := HashUserID(tt.userID2)
|
||||||
|
if err2 != nil {
|
||||||
|
t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hash1 == hash2 {
|
||||||
|
t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserIDHash_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hash UserIDHash
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "zero hash",
|
||||||
|
hash: [16]byte{},
|
||||||
|
expected: "00000000000000000000000000000000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "small value",
|
||||||
|
hash: [16]byte{15: 0xff},
|
||||||
|
expected: "000000000000000000000000000000ff",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large value",
|
||||||
|
hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe},
|
||||||
|
expected: "0000000000000000deadbeefcafebabe",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max value",
|
||||||
|
hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||||
|
expected: "ffffffffffffffffffffffffffffffff",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.hash.String()
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserIDHash_String_Length(t *testing.T) {
|
||||||
|
// Test that String() always returns 32 hex characters (16 bytes * 2)
|
||||||
|
userID := "test@example.com"
|
||||||
|
hash, err := HashUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HashUserID() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := hash.String()
|
||||||
|
if len(result) != 32 {
|
||||||
|
t.Errorf("UserIDHash.String() length = %d, want 32", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's valid hex
|
||||||
|
for i, c := range result {
|
||||||
|
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||||
|
t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashUserID_KnownValues(t *testing.T) {
|
||||||
|
// Test with known BLAKE2b-128 values to ensure correct implementation
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID string
|
||||||
|
expected UserIDHash
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
userID: "",
|
||||||
|
// BLAKE2b-128 of empty string
|
||||||
|
expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single character 'a'",
|
||||||
|
userID: "a",
|
||||||
|
// BLAKE2b-128 of "a"
|
||||||
|
expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hash, err := HashUserID(tt.userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("HashUserID() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hash != tt.expected {
|
||||||
|
t.Errorf("HashUserID(%q) = %x, want %x",
|
||||||
|
tt.userID, hash, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHashUserID(b *testing.B) {
|
||||||
|
userID := "user@example.com"
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = HashUserID(userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUserIDHash_String(b *testing.B) {
|
||||||
|
hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe})
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = hash.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,12 +2,10 @@ package semaphoregroup
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore.
|
// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore.
|
||||||
type SemaphoreGroup struct {
|
type SemaphoreGroup struct {
|
||||||
waitGroup sync.WaitGroup
|
|
||||||
semaphore chan struct{}
|
semaphore chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -18,31 +16,18 @@ func NewSemaphoreGroup(limit int) *SemaphoreGroup {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add increments the internal WaitGroup counter and acquires a semaphore slot.
|
// Add acquire a slot
|
||||||
func (sg *SemaphoreGroup) Add(ctx context.Context) {
|
func (sg *SemaphoreGroup) Add(ctx context.Context) error {
|
||||||
sg.waitGroup.Add(1)
|
|
||||||
|
|
||||||
// Acquire semaphore slot
|
// Acquire semaphore slot
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return ctx.Err()
|
||||||
case sg.semaphore <- struct{}{}:
|
case sg.semaphore <- struct{}{}:
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Done decrements the internal WaitGroup counter and releases a semaphore slot.
|
// Done releases a slot. Must be called after a successful Add.
|
||||||
func (sg *SemaphoreGroup) Done(ctx context.Context) {
|
func (sg *SemaphoreGroup) Done() {
|
||||||
sg.waitGroup.Done()
|
<-sg.semaphore
|
||||||
|
|
||||||
// Release semaphore slot
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-sg.semaphore:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait waits until the internal WaitGroup counter is zero.
|
|
||||||
func (sg *SemaphoreGroup) Wait() {
|
|
||||||
sg.waitGroup.Wait()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,65 +2,89 @@ package semaphoregroup
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSemaphoreGroup(t *testing.T) {
|
func TestSemaphoreGroup(t *testing.T) {
|
||||||
semGroup := NewSemaphoreGroup(2)
|
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
semGroup.Add(context.Background())
|
|
||||||
go func(id int) {
|
|
||||||
defer semGroup.Done(context.Background())
|
|
||||||
|
|
||||||
got := len(semGroup.semaphore)
|
|
||||||
if got == 0 {
|
|
||||||
t.Errorf("Expected semaphore length > 0 , got 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(time.Millisecond)
|
|
||||||
t.Logf("Goroutine %d is running\n", id)
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
semGroup.Wait()
|
|
||||||
|
|
||||||
want := 0
|
|
||||||
got := len(semGroup.semaphore)
|
|
||||||
if got != want {
|
|
||||||
t.Errorf("Expected semaphore length %d, got %d", want, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSemaphoreGroupContext(t *testing.T) {
|
|
||||||
semGroup := NewSemaphoreGroup(1)
|
semGroup := NewSemaphoreGroup(1)
|
||||||
semGroup.Add(context.Background())
|
_ = semGroup.Add(context.Background())
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
|
||||||
|
ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
t.Cleanup(cancel)
|
t.Cleanup(cancel)
|
||||||
rChan := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
if err := semGroup.Add(ctxTimeout); err == nil {
|
||||||
semGroup.Add(ctx)
|
t.Error("Adding to semaphore group should not block")
|
||||||
rChan <- struct{}{}
|
}
|
||||||
}()
|
}
|
||||||
select {
|
|
||||||
case <-rChan:
|
func TestSemaphoreGroupFreeUp(t *testing.T) {
|
||||||
case <-time.NewTimer(2 * time.Second).C:
|
semGroup := NewSemaphoreGroup(1)
|
||||||
t.Error("Adding to semaphore group should not block when context is not done")
|
_ = semGroup.Add(context.Background())
|
||||||
}
|
semGroup.Done()
|
||||||
|
|
||||||
semGroup.Done(context.Background())
|
ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
|
t.Cleanup(cancel)
|
||||||
ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second)
|
if err := semGroup.Add(ctxTimeout); err != nil {
|
||||||
t.Cleanup(cancelDone)
|
t.Error(err)
|
||||||
go func() {
|
}
|
||||||
semGroup.Done(ctxDone)
|
}
|
||||||
rChan <- struct{}{}
|
|
||||||
}()
|
func TestSemaphoreGroupCanceledContext(t *testing.T) {
|
||||||
select {
|
semGroup := NewSemaphoreGroup(1)
|
||||||
case <-rChan:
|
_ = semGroup.Add(context.Background())
|
||||||
case <-time.NewTimer(2 * time.Second).C:
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
t.Error("Releasing from semaphore group should not block when context is not done")
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
if err := semGroup.Add(ctx); err == nil {
|
||||||
|
t.Error("Add should return error when context is already canceled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSemaphoreGroupCancelWhileWaiting(t *testing.T) {
|
||||||
|
semGroup := NewSemaphoreGroup(1)
|
||||||
|
_ = semGroup.Add(context.Background())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- semGroup.Add(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err := <-errChan; err == nil {
|
||||||
|
t.Error("Add should return error when context is canceled while waiting")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSemaphoreGroupHighConcurrency(t *testing.T) {
|
||||||
|
const limit = 10
|
||||||
|
const numGoroutines = 100
|
||||||
|
|
||||||
|
semGroup := NewSemaphoreGroup(limit)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := semGroup.Add(context.Background()); err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
semGroup.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Verify all slots were released
|
||||||
|
if got := len(semGroup.semaphore); got != 0 {
|
||||||
|
t.Errorf("Expected semaphore to be empty, got %d slots occupied", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user