mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-05 23:49:54 +00:00
Compare commits
2 Commits
dn-reverse
...
trigger-pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfeb60fbb5 | ||
|
|
ea41cf2d2c |
@@ -1,6 +0,0 @@
|
|||||||
.env
|
|
||||||
.env.*
|
|
||||||
*.pem
|
|
||||||
*.key
|
|
||||||
*.crt
|
|
||||||
*.p12
|
|
||||||
10
.github/workflows/check-license-dependencies.yml
vendored
10
.github/workflows/check-license-dependencies.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Check for problematic license dependencies
|
- name: Check for problematic license dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
|
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
# Find all directories except the problematic ones and system dirs
|
||||||
@@ -31,7 +31,7 @@ jobs:
|
|||||||
while IFS= read -r dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
if [ -n "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
@@ -39,11 +39,11 @@ jobs:
|
|||||||
else
|
else
|
||||||
echo "✓ No problematic dependencies found"
|
echo "✓ No problematic dependencies found"
|
||||||
fi
|
fi
|
||||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
|
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
# Check if any importer is NOT in management/signal/relay
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
|||||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||||
|
|
||||||
|
|||||||
1
.github/workflows/golang-test-freebsd.yml
vendored
1
.github/workflows/golang-test-freebsd.yml
vendored
@@ -46,5 +46,6 @@ jobs:
|
|||||||
time go test -timeout 1m -failfast ./client/iface/...
|
time go test -timeout 1m -failfast ./client/iface/...
|
||||||
time go test -timeout 1m -failfast ./route/...
|
time go test -timeout 1m -failfast ./route/...
|
||||||
time go test -timeout 1m -failfast ./sharedsock/...
|
time go test -timeout 1m -failfast ./sharedsock/...
|
||||||
|
time go test -timeout 1m -failfast ./signal/...
|
||||||
time go test -timeout 1m -failfast ./util/...
|
time go test -timeout 1m -failfast ./util/...
|
||||||
time go test -timeout 1m -failfast ./version/...
|
time go test -timeout 1m -failfast ./version/...
|
||||||
|
|||||||
51
.github/workflows/golang-test-linux.yml
vendored
51
.github/workflows/golang-test-linux.yml
vendored
@@ -144,7 +144,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
name: "Client (Docker) / Unit"
|
name: "Client (Docker) / Unit"
|
||||||
@@ -204,7 +204,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -261,53 +261,6 @@ jobs:
|
|||||||
-exec 'sudo' \
|
-exec 'sudo' \
|
||||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||||
|
|
||||||
test_proxy:
|
|
||||||
name: "Proxy / Unit"
|
|
||||||
needs: [build-cache]
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
arch: [ '386','amd64' ]
|
|
||||||
runs-on: ubuntu-22.04
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version-file: "go.mod"
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
|
||||||
|
|
||||||
- name: Get Go environment
|
|
||||||
run: |
|
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
|
||||||
uses: actions/cache/restore@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
${{ env.cache }}
|
|
||||||
${{ env.modcache }}
|
|
||||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-gotest-cache-
|
|
||||||
|
|
||||||
- name: Install modules
|
|
||||||
run: go mod tidy
|
|
||||||
|
|
||||||
- name: check git status
|
|
||||||
run: git --no-pager diff --exit-code
|
|
||||||
|
|
||||||
- name: Test
|
|
||||||
run: |
|
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
|
||||||
go test -timeout 10m -p 1 ./proxy/...
|
|
||||||
|
|
||||||
test_signal:
|
test_signal:
|
||||||
name: "Signal / Unit"
|
name: "Signal / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
|
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
||||||
skip: go.mod,go.sum,**/proxy/web/**
|
skip: go.mod,go.sum
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.1"
|
SIGN_PIPE_VER: "v0.1.0"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,7 +2,6 @@
|
|||||||
.run
|
.run
|
||||||
*.iml
|
*.iml
|
||||||
dist/
|
dist/
|
||||||
!proxy/web/dist/
|
|
||||||
bin/
|
bin/
|
||||||
.env
|
.env
|
||||||
conf.json
|
conf.json
|
||||||
|
|||||||
@@ -60,8 +60,8 @@
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### Self-Host NetBird (Video)
|
### NetBird on Lawrence Systems (Video)
|
||||||
[](https://youtu.be/bZAgpT6nzaQ)
|
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
|
|||||||
@@ -282,9 +282,13 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
}
|
}
|
||||||
defer authClient.Close()
|
defer authClient.Close()
|
||||||
|
|
||||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
needsLogin := false
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("check login required: %v", err)
|
err, isAuthError := authClient.Login(ctx, "", "")
|
||||||
|
if isAuthError {
|
||||||
|
needsLogin = true
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("login check failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
|
|||||||
@@ -31,14 +31,6 @@ var (
|
|||||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerConnStatus is a peer's connection status.
|
|
||||||
type PeerConnStatus = peer.ConnStatus
|
|
||||||
|
|
||||||
const (
|
|
||||||
// PeerStatusConnected indicates the peer is in connected state.
|
|
||||||
PeerStatusConnected = peer.StatusConnected
|
|
||||||
)
|
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
@@ -79,8 +71,6 @@ type Options struct {
|
|||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
// BlockInbound blocks all inbound connections from peers
|
// BlockInbound blocks all inbound connections from peers
|
||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
|
||||||
WireguardPort *int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -150,7 +140,6 @@ func New(opts Options) (*Client, error) {
|
|||||||
DisableServerRoutes: &t,
|
DisableServerRoutes: &t,
|
||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
BlockInbound: &opts.BlockInbound,
|
BlockInbound: &opts.BlockInbound,
|
||||||
WireguardPort: opts.WireguardPort,
|
|
||||||
}
|
}
|
||||||
if opts.ConfigPath != "" {
|
if opts.ConfigPath != "" {
|
||||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||||
@@ -170,7 +159,6 @@ func New(opts Options) (*Client, error) {
|
|||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
jwtToken: opts.JWTToken,
|
jwtToken: opts.JWTToken,
|
||||||
config: config,
|
config: config,
|
||||||
recorder: peer.NewRecorder(config.ManagementURL.String()),
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,7 +180,6 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
|
||||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create auth client: %w", err)
|
return fmt.Errorf("create auth client: %w", err)
|
||||||
@@ -202,7 +189,10 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
|
||||||
|
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||||
|
c.recorder = recorder
|
||||||
|
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||||
client.SetSyncResponsePersistence(true)
|
client.SetSyncResponsePersistence(true)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
@@ -355,9 +345,14 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
// Status returns the current status of the client.
|
// Status returns the current status of the client.
|
||||||
func (c *Client) Status() (peer.FullStatus, error) {
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
|
recorder := c.recorder
|
||||||
connect := c.connect
|
connect := c.connect
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if recorder == nil {
|
||||||
|
return peer.FullStatus{}, errors.New("client not started")
|
||||||
|
}
|
||||||
|
|
||||||
if connect != nil {
|
if connect != nil {
|
||||||
engine := connect.Engine()
|
engine := connect.Engine()
|
||||||
if engine != nil {
|
if engine != nil {
|
||||||
@@ -365,7 +360,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.recorder.GetFullStatus(), nil
|
return recorder.GetFullStatus(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||||
|
|||||||
@@ -483,12 +483,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nftRule.Handle == 0 {
|
if nftRule.Handle == 0 {
|
||||||
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
|
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||||
if err := r.decrementSetCounter(nftRule); err != nil {
|
|
||||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
@@ -665,32 +660,13 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
r.rollbackRules(pair)
|
// TODO: rollback ipset counter
|
||||||
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
|
||||||
func (r *router) rollbackRules(pair firewall.RouterPair) {
|
|
||||||
keys := []string{
|
|
||||||
firewall.GenKey(firewall.ForwardingFormat, pair),
|
|
||||||
firewall.GenKey(firewall.PreroutingFormat, pair),
|
|
||||||
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
|
|
||||||
}
|
|
||||||
for _, key := range keys {
|
|
||||||
rule, ok := r.rules[key]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
log.Warnf("rollback set counter for %s: %v", key, err)
|
|
||||||
}
|
|
||||||
delete(r.rules, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
@@ -952,30 +928,18 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
rule, exists := r.rules[ruleKey]
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
if !exists {
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
return nil
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
|
||||||
|
|
||||||
if rule.Handle == 0 {
|
|
||||||
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1365,89 +1329,65 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if pair.Masquerade {
|
if pair.Masquerade {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
|
||||||
// counters will be off until the next successful removal or refresh cycle.
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
// TODO: rollback set counter
|
||||||
}
|
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|
||||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
|
||||||
|
|
||||||
rule, exists := r.rules[ruleKey]
|
|
||||||
if !exists {
|
|
||||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.Handle == 0 {
|
|
||||||
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
|
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||||
// (e.g. from failed flushes) and updates handles for all existing rules.
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||||
|
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||||
func (r *router) refreshRulesMap() error {
|
func (r *router) refreshRulesMap() error {
|
||||||
var merr *multierror.Error
|
|
||||||
newRules := make(map[string]*nftables.Rule)
|
|
||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
return fmt.Errorf("list rules: %w", err)
|
||||||
// preserve existing entries for this chain since we can't verify their state
|
|
||||||
for k, v := range r.rules {
|
|
||||||
if v.Chain != nil && v.Chain.Name == chain.Name {
|
|
||||||
newRules[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
newRules[string(rule.UserData)] = rule
|
r.rules[string(rule.UserData)] = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.rules = newRules
|
return nil
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
@@ -1689,34 +1629,20 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
var needsFlush bool
|
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
if dnatRule.Handle == 0 {
|
if err := r.conn.DelRule(dnatRule); err != nil {
|
||||||
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
|
|
||||||
delete(r.rules, ruleKey+dnatSuffix)
|
|
||||||
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||||
} else {
|
|
||||||
needsFlush = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
if masqRule.Handle == 0 {
|
if err := r.conn.DelRule(masqRule); err != nil {
|
||||||
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
|
|
||||||
delete(r.rules, ruleKey+snatSuffix)
|
|
||||||
} else if err := r.conn.DelRule(masqRule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||||
} else {
|
|
||||||
needsFlush = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if needsFlush {
|
if err := r.conn.Flush(); err != nil {
|
||||||
if err := r.conn.Flush(); err != nil {
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
@@ -1831,25 +1757,16 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
|
|
||||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
rule, exists := r.rules[ruleID]
|
if rule, exists := r.rules[ruleID]; exists {
|
||||||
if !exists {
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
return nil
|
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||||
}
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
if rule.Handle == 0 {
|
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||||
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
}
|
||||||
delete(r.rules, ruleID)
|
delete(r.rules, ruleID)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -720,137 +719,3 @@ func deleteWorkTable() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this system")
|
|
||||||
}
|
|
||||||
|
|
||||||
workTable, err := createWorkTable()
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer deleteWorkTable()
|
|
||||||
|
|
||||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, r.init(workTable))
|
|
||||||
defer func() { require.NoError(t, r.Reset()) }()
|
|
||||||
|
|
||||||
// Add a real rule to the kernel
|
|
||||||
ruleKey, err := r.AddRouteFiltering(
|
|
||||||
nil,
|
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
|
||||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
|
||||||
firewall.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&firewall.Port{Values: []uint16{80}},
|
|
||||||
firewall.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
|
||||||
})
|
|
||||||
|
|
||||||
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
|
||||||
staleKey := "stale-rule-that-does-not-exist"
|
|
||||||
r.rules[staleKey] = &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameRoutingFw],
|
|
||||||
Handle: 0,
|
|
||||||
UserData: []byte(staleKey),
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
|
|
||||||
|
|
||||||
err = r.refreshRulesMap()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
|
|
||||||
|
|
||||||
realRule, ok := r.rules[ruleKey.ID()]
|
|
||||||
assert.True(t, ok, "real rule should still exist after refresh")
|
|
||||||
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this system")
|
|
||||||
}
|
|
||||||
|
|
||||||
workTable, err := createWorkTable()
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer deleteWorkTable()
|
|
||||||
|
|
||||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, r.init(workTable))
|
|
||||||
defer func() { require.NoError(t, r.Reset()) }()
|
|
||||||
|
|
||||||
// Inject a stale entry with Handle=0
|
|
||||||
staleKey := "stale-route-rule"
|
|
||||||
r.rules[staleKey] = &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameRoutingFw],
|
|
||||||
Handle: 0,
|
|
||||||
UserData: []byte(staleKey),
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRouteRule should not return an error for stale handles
|
|
||||||
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
|
||||||
assert.NoError(t, err, "deleting a stale rule should not error")
|
|
||||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this system")
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, manager.Init(nil))
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, manager.Close(nil))
|
|
||||||
})
|
|
||||||
|
|
||||||
pair := firewall.RouterPair{
|
|
||||||
ID: "staletest",
|
|
||||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
|
||||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
|
||||||
Masquerade: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
rtr := manager.router
|
|
||||||
|
|
||||||
// First add succeeds
|
|
||||||
err = rtr.AddNatRule(pair)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, rtr.RemoveNatRule(pair))
|
|
||||||
})
|
|
||||||
|
|
||||||
// Corrupt the handle to simulate stale state
|
|
||||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
|
||||||
if rule, exists := rtr.rules[natRuleKey]; exists {
|
|
||||||
rule.Handle = 0
|
|
||||||
}
|
|
||||||
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
|
||||||
if rule, exists := rtr.rules[inverseKey]; exists {
|
|
||||||
rule.Handle = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adding the same rule again should succeed despite stale handles
|
|
||||||
err = rtr.AddNatRule(pair)
|
|
||||||
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
|
|
||||||
|
|
||||||
// Verify rules exist in kernel
|
|
||||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
found := 0
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
|
||||||
found++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,6 +3,12 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -11,7 +17,33 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.resetState()
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
||||||
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
fwder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -23,7 +26,33 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.resetState()
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
||||||
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
fwder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -115,17 +115,6 @@ func (t *TCPConnTrack) IsTombstone() bool {
|
|||||||
return t.tombstone.Load()
|
return t.tombstone.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSupersededBy returns true if this connection should be replaced by a new one
|
|
||||||
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
|
|
||||||
// connections are superseded by a pure SYN (a new connection attempt for the same
|
|
||||||
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
|
|
||||||
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
|
|
||||||
if t.tombstone.Load() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTombstone safely marks the connection for deletion
|
// SetTombstone safely marks the connection for deletion
|
||||||
func (t *TCPConnTrack) SetTombstone() {
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
t.tombstone.Store(true)
|
t.tombstone.Store(true)
|
||||||
@@ -180,7 +169,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if exists && !conn.IsSupersededBy(flags) {
|
if exists {
|
||||||
t.updateState(key, conn, flags, direction, size)
|
t.updateState(key, conn, flags, direction, size)
|
||||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||||
}
|
}
|
||||||
@@ -252,7 +241,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists || conn.IsSupersededBy(flags) {
|
if !exists || conn.IsTombstone() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -485,261 +485,6 @@ func TestTCPAbnormalSequences(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
|
|
||||||
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
|
|
||||||
// updateIfExists treats tombstoned entries as live, causing track() to skip
|
|
||||||
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
|
|
||||||
// because the entry is tombstoned, and the response packet gets dropped by ACL.
|
|
||||||
func TestTCPPortReuseTombstone(t *testing.T) {
|
|
||||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
|
||||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
|
||||||
srcPort := uint16(12345)
|
|
||||||
dstPort := uint16(80)
|
|
||||||
|
|
||||||
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
|
||||||
|
|
||||||
// Establish and gracefully close a connection (server-initiated close)
|
|
||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
|
||||||
|
|
||||||
// Server sends FIN
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
|
||||||
require.True(t, valid)
|
|
||||||
|
|
||||||
// Client sends FIN-ACK
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
|
||||||
|
|
||||||
// Server sends final ACK
|
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
|
||||||
require.True(t, valid)
|
|
||||||
|
|
||||||
// Connection should be tombstoned
|
|
||||||
conn := tracker.connections[key]
|
|
||||||
require.NotNil(t, conn, "old connection should still be in map")
|
|
||||||
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
|
|
||||||
|
|
||||||
// Now reuse the same port for a new connection
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
|
||||||
|
|
||||||
// The old tombstoned entry should be replaced with a new one
|
|
||||||
newConn := tracker.connections[key]
|
|
||||||
require.NotNil(t, newConn, "new connection should exist")
|
|
||||||
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
|
||||||
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
|
||||||
|
|
||||||
// SYN-ACK for the new connection should be valid
|
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
|
||||||
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
|
|
||||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
|
||||||
|
|
||||||
// Data transfer should work
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
|
|
||||||
require.True(t, valid, "data should be allowed on new connection")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Outbound port reuse after RST", func(t *testing.T) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
|
||||||
|
|
||||||
// Establish and RST a connection
|
|
||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
|
||||||
require.True(t, valid)
|
|
||||||
|
|
||||||
conn := tracker.connections[key]
|
|
||||||
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
|
|
||||||
|
|
||||||
// Reuse the same port
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
|
||||||
|
|
||||||
newConn := tracker.connections[key]
|
|
||||||
require.NotNil(t, newConn)
|
|
||||||
require.False(t, newConn.IsTombstone())
|
|
||||||
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
|
||||||
|
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
|
||||||
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Inbound port reuse after close", func(t *testing.T) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
clientIP := srcIP
|
|
||||||
serverIP := dstIP
|
|
||||||
clientPort := srcPort
|
|
||||||
serverPort := dstPort
|
|
||||||
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
|
|
||||||
|
|
||||||
// Inbound connection: client SYN → server SYN-ACK → client ACK
|
|
||||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
|
||||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
|
||||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
|
||||||
|
|
||||||
conn := tracker.connections[key]
|
|
||||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
|
||||||
|
|
||||||
// Server-initiated close to reach Closed/tombstoned:
|
|
||||||
// Server FIN (opposite dir) → CloseWait
|
|
||||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
|
|
||||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
|
||||||
// Client FIN-ACK (same dir as conn) → LastAck
|
|
||||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
|
|
||||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
|
||||||
// Server final ACK (opposite dir) → Closed → tombstoned
|
|
||||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
|
||||||
|
|
||||||
require.True(t, conn.IsTombstone())
|
|
||||||
|
|
||||||
// New inbound connection on same ports
|
|
||||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
|
||||||
|
|
||||||
newConn := tracker.connections[key]
|
|
||||||
require.NotNil(t, newConn)
|
|
||||||
require.False(t, newConn.IsTombstone())
|
|
||||||
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
|
||||||
|
|
||||||
// Complete handshake: server SYN-ACK, then client ACK
|
|
||||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
|
||||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
|
||||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
|
||||||
|
|
||||||
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
|
|
||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
|
|
||||||
|
|
||||||
conn := tracker.connections[key]
|
|
||||||
require.True(t, conn.IsTombstone())
|
|
||||||
|
|
||||||
// Late ACK should be rejected (tombstoned)
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
|
||||||
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
|
|
||||||
|
|
||||||
// Late outbound ACK should not create a new connection (not a SYN)
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
|
||||||
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTCPPortReuseTimeWait(t *testing.T) {
|
|
||||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
|
||||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
|
||||||
srcPort := uint16(12345)
|
|
||||||
dstPort := uint16(80)
|
|
||||||
|
|
||||||
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
|
||||||
|
|
||||||
// Establish connection
|
|
||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
|
||||||
|
|
||||||
// Active close: client (outbound initiator) sends FIN first
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
|
||||||
conn := tracker.connections[key]
|
|
||||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
|
||||||
|
|
||||||
// Server ACKs the FIN
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
|
||||||
require.True(t, valid)
|
|
||||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
|
||||||
|
|
||||||
// Server sends its own FIN
|
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
|
||||||
require.True(t, valid)
|
|
||||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
|
||||||
|
|
||||||
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
|
||||||
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
|
|
||||||
|
|
||||||
// New outbound SYN on the same port (port reuse during TIME-WAIT)
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
|
||||||
|
|
||||||
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
|
|
||||||
newConn := tracker.connections[key]
|
|
||||||
require.NotNil(t, newConn, "new connection should exist")
|
|
||||||
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
|
||||||
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
|
|
||||||
|
|
||||||
// SYN-ACK for new connection should be valid
|
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
|
||||||
require.True(t, valid, "SYN-ACK for new connection should be accepted")
|
|
||||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
|
||||||
|
|
||||||
// Establish outbound connection and close via active close → TIME-WAIT
|
|
||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
|
||||||
|
|
||||||
conn := tracker.connections[key]
|
|
||||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
|
||||||
|
|
||||||
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
|
|
||||||
// so the filter falls through to ACL check + TrackInbound (which creates
|
|
||||||
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
|
|
||||||
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
|
|
||||||
|
|
||||||
// Simulate what the filter does next: TrackInbound via the normal path
|
|
||||||
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
|
|
||||||
|
|
||||||
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
|
|
||||||
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
|
|
||||||
newConn := tracker.connections[invertedKey]
|
|
||||||
require.NotNil(t, newConn, "new inbound connection should be tracked")
|
|
||||||
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
|
||||||
require.False(t, newConn.IsTombstone())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
|
||||||
|
|
||||||
// Establish and active close → TIME-WAIT
|
|
||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
|
||||||
|
|
||||||
conn := tracker.connections[key]
|
|
||||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
|
||||||
|
|
||||||
// Late ACK retransmits during TIME-WAIT should still be accepted
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
|
||||||
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTCPTimeoutHandling(t *testing.T) {
|
func TestTCPTimeoutHandling(t *testing.T) {
|
||||||
// Create tracker with a very short timeout for testing
|
// Create tracker with a very short timeout for testing
|
||||||
shortTimeout := 100 * time.Millisecond
|
shortTimeout := 100 * time.Millisecond
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -13,13 +12,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
@@ -27,7 +24,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -93,7 +89,6 @@ type Manager struct {
|
|||||||
incomingDenyRules map[netip.Addr]RuleSet
|
incomingDenyRules map[netip.Addr]RuleSet
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@@ -234,7 +229,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
|
||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
portDNATRules: []portDNATRule{},
|
portDNATRules: []portDNATRule{},
|
||||||
netstackServices: make(map[serviceKey]struct{}),
|
netstackServices: make(map[serviceKey]struct{}),
|
||||||
@@ -486,15 +480,11 @@ func (m *Manager) addRouteFiltering(
|
|||||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleID := uuid.New().String()
|
||||||
|
|
||||||
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
|
||||||
return existingRule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
// TODO: consolidate these IDs
|
// TODO: consolidate these IDs
|
||||||
id: string(ruleKey),
|
id: ruleID,
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
dstSet: destination.Set,
|
dstSet: destination.Set,
|
||||||
@@ -509,7 +499,6 @@ func (m *Manager) addRouteFiltering(
|
|||||||
|
|
||||||
m.routeRules = append(m.routeRules, &rule)
|
m.routeRules = append(m.routeRules, &rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
m.routeRulesMap[ruleKey] = &rule
|
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
@@ -526,20 +515,15 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
|||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleKey := nbid.RuleID(rule.ID())
|
ruleID := rule.ID()
|
||||||
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
|
||||||
return fmt.Errorf("route rule not found: %s", ruleKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||||
return r.id == string(ruleKey)
|
return r.id == ruleID
|
||||||
})
|
})
|
||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||||
delete(m.routeRulesMap, ruleKey)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,40 +570,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// resetState clears all firewall rules and closes connection trackers.
|
|
||||||
// Must be called with m.mutex held.
|
|
||||||
func (m *Manager) resetState() {
|
|
||||||
maps.Clear(m.outgoingRules)
|
|
||||||
maps.Clear(m.incomingDenyRules)
|
|
||||||
maps.Clear(m.incomingRules)
|
|
||||||
maps.Clear(m.routeRulesMap)
|
|
||||||
m.routeRules = m.routeRules[:0]
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
|
|||||||
@@ -1,376 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
|
|
||||||
// filtering rule twice returns the same rule ID (idempotent behavior).
|
|
||||||
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
|
||||||
manager := setupTestManager(t)
|
|
||||||
|
|
||||||
sources := []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("100.64.1.0/24"),
|
|
||||||
netip.MustParsePrefix("100.64.2.0/24"),
|
|
||||||
}
|
|
||||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
|
||||||
|
|
||||||
// Add rule first time
|
|
||||||
rule1, err := manager.AddRouteFiltering(
|
|
||||||
[]byte("policy-1"),
|
|
||||||
sources,
|
|
||||||
destination,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []uint16{443}},
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rule1)
|
|
||||||
|
|
||||||
// Add the same rule again
|
|
||||||
rule2, err := manager.AddRouteFiltering(
|
|
||||||
[]byte("policy-1"),
|
|
||||||
sources,
|
|
||||||
destination,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []uint16{443}},
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rule2)
|
|
||||||
|
|
||||||
// These should be the same (idempotent) like nftables/iptables implementations
|
|
||||||
assert.Equal(t, rule1.ID(), rule2.ID(),
|
|
||||||
"Adding the same rule twice should return the same rule ID (idempotent)")
|
|
||||||
|
|
||||||
manager.mutex.RLock()
|
|
||||||
ruleCount := len(manager.routeRules)
|
|
||||||
manager.mutex.RUnlock()
|
|
||||||
|
|
||||||
assert.Equal(t, 2, ruleCount,
|
|
||||||
"Should have exactly 2 rules (1 user rule + 1 block rule)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
|
|
||||||
// different parameters get distinct IDs.
|
|
||||||
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
|
||||||
manager := setupTestManager(t)
|
|
||||||
|
|
||||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
|
||||||
|
|
||||||
// Add first rule
|
|
||||||
rule1, err := manager.AddRouteFiltering(
|
|
||||||
[]byte("policy-1"),
|
|
||||||
sources,
|
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []uint16{443}},
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Add different rule (different destination)
|
|
||||||
rule2, err := manager.AddRouteFiltering(
|
|
||||||
[]byte("policy-2"),
|
|
||||||
sources,
|
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []uint16{443}},
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
|
||||||
"Different rules should have different IDs")
|
|
||||||
|
|
||||||
manager.mutex.RLock()
|
|
||||||
ruleCount := len(manager.routeRules)
|
|
||||||
manager.mutex.RUnlock()
|
|
||||||
|
|
||||||
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
|
|
||||||
// rule during a network map update does not disrupt existing traffic.
|
|
||||||
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
|
||||||
manager := setupTestManager(t)
|
|
||||||
|
|
||||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
|
||||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
|
||||||
|
|
||||||
rule1, err := manager.AddRouteFiltering(
|
|
||||||
[]byte("policy-1"),
|
|
||||||
sources,
|
|
||||||
destination,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
srcIP := netip.MustParseAddr("100.64.1.5")
|
|
||||||
dstIP := netip.MustParseAddr("192.168.1.10")
|
|
||||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
|
||||||
require.True(t, pass, "Traffic should pass with rule in place")
|
|
||||||
|
|
||||||
// Re-add same rule (simulates network map update)
|
|
||||||
rule2, err := manager.AddRouteFiltering(
|
|
||||||
[]byte("policy-1"),
|
|
||||||
sources,
|
|
||||||
destination,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
|
||||||
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
|
||||||
// would remove the only matching rule and cause a traffic gap.
|
|
||||||
if rule1.ID() != rule2.ID() {
|
|
||||||
err = manager.DeleteRouteRule(rule1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
|
||||||
assert.True(t, passAfter,
|
|
||||||
"Traffic should still pass after rule update - no gap should occur")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
|
||||||
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
|
||||||
// returns the same rule without duplicating.
|
|
||||||
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
dev := mocks.NewMockDevice(ctrl)
|
|
||||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
|
||||||
|
|
||||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: wgNet.Addr(),
|
|
||||||
Network: wgNet,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
GetDeviceFunc: func() *device.FilteredDevice {
|
|
||||||
return &device.FilteredDevice{Device: dev}
|
|
||||||
},
|
|
||||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
|
||||||
return &wgdevice.Device{}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, manager.Close(nil))
|
|
||||||
})
|
|
||||||
|
|
||||||
// Call blockInvalidRouted directly multiple times
|
|
||||||
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rule1)
|
|
||||||
|
|
||||||
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rule2)
|
|
||||||
|
|
||||||
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rule3)
|
|
||||||
|
|
||||||
// All should return the same rule
|
|
||||||
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
|
||||||
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
|
||||||
|
|
||||||
// Should have exactly 1 route rule
|
|
||||||
manager.mutex.RLock()
|
|
||||||
ruleCount := len(manager.routeRules)
|
|
||||||
manager.mutex.RUnlock()
|
|
||||||
|
|
||||||
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
|
|
||||||
|
|
||||||
// Verify the rule blocks traffic to the WG network
|
|
||||||
srcIP := netip.MustParseAddr("10.0.0.1")
|
|
||||||
dstIP := netip.MustParseAddr("100.64.0.50")
|
|
||||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
|
||||||
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
|
|
||||||
// EnableRouting multiple times (as happens on each route update) does not
|
|
||||||
// accumulate duplicate block rules in the routeRules slice.
|
|
||||||
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
dev := mocks.NewMockDevice(ctrl)
|
|
||||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
|
||||||
|
|
||||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: wgNet.Addr(),
|
|
||||||
Network: wgNet,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
GetDeviceFunc: func() *device.FilteredDevice {
|
|
||||||
return &device.FilteredDevice{Device: dev}
|
|
||||||
},
|
|
||||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
|
||||||
return &wgdevice.Device{}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, manager.Close(nil))
|
|
||||||
})
|
|
||||||
|
|
||||||
// Call EnableRouting multiple times (simulating repeated route updates)
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
require.NoError(t, manager.EnableRouting())
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.mutex.RLock()
|
|
||||||
ruleCount := len(manager.routeRules)
|
|
||||||
manager.mutex.RUnlock()
|
|
||||||
|
|
||||||
assert.Equal(t, 1, ruleCount,
|
|
||||||
"Repeated EnableRouting should not accumulate block rules")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
|
|
||||||
// rule multiple times does not create duplicate entries.
|
|
||||||
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
|
||||||
manager := setupTestManager(t)
|
|
||||||
|
|
||||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
|
||||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
|
||||||
|
|
||||||
// Simulate 5 network map updates with the same route rule
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
rule, err := manager.AddRouteFiltering(
|
|
||||||
[]byte("policy-1"),
|
|
||||||
sources,
|
|
||||||
destination,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []uint16{443}},
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.mutex.RLock()
|
|
||||||
ruleCount := len(manager.routeRules)
|
|
||||||
manager.mutex.RUnlock()
|
|
||||||
|
|
||||||
assert.Equal(t, 2, ruleCount,
|
|
||||||
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
|
|
||||||
// after adding it multiple times works correctly.
|
|
||||||
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
|
||||||
manager := setupTestManager(t)
|
|
||||||
|
|
||||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
|
||||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
|
||||||
|
|
||||||
// Add same rule twice
|
|
||||||
rule1, err := manager.AddRouteFiltering(
|
|
||||||
[]byte("policy-1"),
|
|
||||||
sources,
|
|
||||||
destination,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
rule2, err := manager.AddRouteFiltering(
|
|
||||||
[]byte("policy-1"),
|
|
||||||
sources,
|
|
||||||
destination,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
|
||||||
|
|
||||||
// Delete using first reference
|
|
||||||
err = manager.DeleteRouteRule(rule1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify traffic no longer passes
|
|
||||||
srcIP := netip.MustParseAddr("100.64.1.5")
|
|
||||||
dstIP := netip.MustParseAddr("192.168.1.10")
|
|
||||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
|
||||||
assert.False(t, pass, "Traffic should not pass after rule deletion")
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupTestManager(t *testing.T) *Manager {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
dev := mocks.NewMockDevice(ctrl)
|
|
||||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
|
||||||
|
|
||||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: wgNet.Addr(),
|
|
||||||
Network: wgNet,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
GetDeviceFunc: func() *device.FilteredDevice {
|
|
||||||
return &device.FilteredDevice{Device: dev}
|
|
||||||
},
|
|
||||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
|
||||||
return &wgdevice.Device{}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, manager.EnableRouting())
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, manager.Close(nil))
|
|
||||||
})
|
|
||||||
|
|
||||||
return manager
|
|
||||||
}
|
|
||||||
@@ -263,158 +263,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
|
||||||
// to the deny map and can be cleanly deleted without leaving orphans.
|
|
||||||
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, m.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
addr := netip.MustParseAddr("192.168.1.1")
|
|
||||||
|
|
||||||
// Add multiple deny rules for different ports
|
|
||||||
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
|
||||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
|
||||||
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
m.mutex.RLock()
|
|
||||||
denyCount := len(m.incomingDenyRules[addr])
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
|
||||||
|
|
||||||
// Delete the first deny rule
|
|
||||||
err = m.DeletePeerRule(rule1[0])
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
m.mutex.RLock()
|
|
||||||
denyCount = len(m.incomingDenyRules[addr])
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
|
||||||
|
|
||||||
// Delete the second deny rule
|
|
||||||
err = m.DeletePeerRule(rule2[0])
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
m.mutex.RLock()
|
|
||||||
_, exists := m.incomingDenyRules[addr]
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
|
||||||
// peer rules (simulating network map updates) does not leak rules in the maps.
|
|
||||||
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, m.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
addr := netip.MustParseAddr("192.168.1.1")
|
|
||||||
|
|
||||||
// Simulate 10 network map updates: add rule, delete old, add new
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
// Add a deny rule
|
|
||||||
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
|
||||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Add an allow rule
|
|
||||||
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
|
||||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Delete them (simulating ACL manager cleanup)
|
|
||||||
for _, r := range rules {
|
|
||||||
require.NoError(t, m.DeletePeerRule(r))
|
|
||||||
}
|
|
||||||
for _, r := range allowRules {
|
|
||||||
require.NoError(t, m.DeletePeerRule(r))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.RLock()
|
|
||||||
denyCount := len(m.incomingDenyRules[addr])
|
|
||||||
allowCount := len(m.incomingRules[addr])
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
|
|
||||||
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
|
||||||
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
|
|
||||||
// IP are stored in separate maps and don't interfere with each other.
|
|
||||||
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, m.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
|
|
||||||
// Add allow rule for port 80
|
|
||||||
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
|
||||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Add deny rule for port 22
|
|
||||||
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
|
||||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
addr := netip.MustParseAddr("192.168.1.1")
|
|
||||||
m.mutex.RLock()
|
|
||||||
allowCount := len(m.incomingRules[addr])
|
|
||||||
denyCount := len(m.incomingDenyRules[addr])
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
|
|
||||||
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
|
||||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
|
||||||
|
|
||||||
// Delete allow rule should not affect deny rule
|
|
||||||
err = m.DeletePeerRule(allowRule[0])
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
m.mutex.RLock()
|
|
||||||
denyCountAfter := len(m.incomingDenyRules[addr])
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
|
|
||||||
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
|
||||||
|
|
||||||
// Delete deny rule
|
|
||||||
err = m.DeletePeerRule(denyRule[0])
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
m.mutex.RLock()
|
|
||||||
_, denyExists := m.incomingDenyRules[addr]
|
|
||||||
_, allowExists := m.incomingRules[addr]
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
|
|
||||||
require.False(t, denyExists, "Deny rules should be empty")
|
|
||||||
require.False(t, allowExists, "Allow rules should be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManagerReset(t *testing.T) {
|
func TestManagerReset(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,18 +16,9 @@ const (
|
|||||||
maxBatchSize = 1024 * 16
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2
|
maxMessageSize = 1024 * 2
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
defaultLogChanSize = 1000
|
logChannelSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
func getLogChannelSize() int {
|
|
||||||
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
|
|
||||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return defaultLogChanSize
|
|
||||||
}
|
|
||||||
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -80,7 +69,7 @@ type Logger struct {
|
|||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
msgChannel: make(chan logMessage, getLogChannelSize()),
|
msgChannel: make(chan logMessage, logChannelSize),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
|
|||||||
@@ -29,9 +29,8 @@ type PacketFilter interface {
|
|||||||
type FilteredDevice struct {
|
type FilteredDevice struct {
|
||||||
tun.Device
|
tun.Device
|
||||||
|
|
||||||
filter PacketFilter
|
filter PacketFilter
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
closeOnce sync.Once
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newDeviceFilter constructor function
|
// newDeviceFilter constructor function
|
||||||
@@ -41,20 +40,6 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the underlying tun device exactly once.
|
|
||||||
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
|
|
||||||
// and multiple code paths can trigger Close on the same device.
|
|
||||||
func (d *FilteredDevice) Close() error {
|
|
||||||
var err error
|
|
||||||
d.closeOnce.Do(func() {
|
|
||||||
err = d.Device.Close()
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read wraps read method with filtering feature
|
// Read wraps read method with filtering feature
|
||||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||||
|
|||||||
@@ -82,9 +82,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cErr := tunIface.Close(); cErr != nil {
|
_ = tunIface.Close()
|
||||||
log.Debugf("failed to close tun device: %v", cErr)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("error configuring interface: %s", err)
|
return nil, fmt.Errorf("error configuring interface: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
@@ -229,10 +228,6 @@ func (w *WGIface) Close() error {
|
|||||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if nbnetstack.IsEnabled() {
|
|
||||||
return errors.FormatErrorOrNil(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.waitUntilRemoved(); err != nil {
|
if err := w.waitUntilRemoved(); err != nil {
|
||||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||||
if err := w.Destroy(); err != nil {
|
if err := w.Destroy(); err != nil {
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return t.tundev, tunNet, nil
|
return nsTunDev, tunNet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Close() error {
|
func (t *NetStackTun) Close() error {
|
||||||
|
|||||||
@@ -189,212 +189,6 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
|
|
||||||
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
|
|
||||||
// This tests the full ACL manager -> uspfilter integration.
|
|
||||||
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "22",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "80",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
FirewallRulesIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
|
||||||
IP: network.Addr(),
|
|
||||||
Network: network,
|
|
||||||
}).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
|
||||||
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, fw.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
acl := NewDefaultManager(fw)
|
|
||||||
|
|
||||||
// Apply the same rules 5 times (simulating repeated network map updates)
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
acl.ApplyFiltering(networkMap, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
|
|
||||||
assert.Equal(t, 3, len(acl.peerRulesPairs),
|
|
||||||
"Should have exactly 3 rule pairs after 5 identical updates")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
|
|
||||||
// up when they're removed from the network map in a subsequent update.
|
|
||||||
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
|
||||||
IP: network.Addr(),
|
|
||||||
Network: network,
|
|
||||||
}).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
|
||||||
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, fw.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
acl := NewDefaultManager(fw)
|
|
||||||
|
|
||||||
// First update: add deny and accept rules
|
|
||||||
networkMap1 := &mgmProto.NetworkMap{
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "22",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
FirewallRulesIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap1, false)
|
|
||||||
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
|
|
||||||
|
|
||||||
// Second update: remove the deny rule, keep only accept
|
|
||||||
networkMap2 := &mgmProto.NetworkMap{
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
FirewallRulesIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap2, false)
|
|
||||||
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
|
||||||
"Should have 1 rule after removing deny rule")
|
|
||||||
|
|
||||||
// Third update: remove all rules
|
|
||||||
networkMap3 := &mgmProto.NetworkMap{
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{},
|
|
||||||
FirewallRulesIsEmpty: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap3, false)
|
|
||||||
assert.Equal(t, 0, len(acl.peerRulesPairs),
|
|
||||||
"Should have 0 rules after removing all rules")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
|
|
||||||
// accept to deny (or vice versa), the old rule is properly removed and the new
|
|
||||||
// one added without leaking.
|
|
||||||
func TestRuleUpdateChangingAction(t *testing.T) {
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
|
||||||
IP: network.Addr(),
|
|
||||||
Network: network,
|
|
||||||
}).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
|
||||||
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, fw.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
acl := NewDefaultManager(fw)
|
|
||||||
|
|
||||||
// First update: accept rule
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "22",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
FirewallRulesIsEmpty: false,
|
|
||||||
}
|
|
||||||
acl.ApplyFiltering(networkMap, false)
|
|
||||||
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
|
||||||
|
|
||||||
// Second update: change to deny (same IP/port/proto, different action)
|
|
||||||
networkMap.FirewallRules = []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "22",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
acl.ApplyFiltering(networkMap, false)
|
|
||||||
|
|
||||||
// Should still have exactly 1 rule (the old accept removed, new deny added)
|
|
||||||
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
|
||||||
"Changing action should result in exactly 1 rule, not 2")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPortInfoEmpty(t *testing.T) {
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -245,7 +244,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
localPeerState := peer.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
|
KernelInterface: device.WireGuardModuleIsLoaded(),
|
||||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||||
}
|
}
|
||||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -29,8 +27,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
type ReadyListener interface {
|
type ReadyListener interface {
|
||||||
OnReady()
|
OnReady()
|
||||||
@@ -443,17 +439,6 @@ func (s *DefaultServer) SearchDomains() []string {
|
|||||||
// ProbeAvailability tests each upstream group's servers for availability
|
// ProbeAvailability tests each upstream group's servers for availability
|
||||||
// and deactivates the group if no server responds
|
// and deactivates the group if no server responds
|
||||||
func (s *DefaultServer) ProbeAvailability() {
|
func (s *DefaultServer) ProbeAvailability() {
|
||||||
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
|
||||||
skipProbe, err := strconv.ParseBool(val)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
|
|
||||||
}
|
|
||||||
if skipProbe {
|
|
||||||
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, mux := range s.dnsMuxMap {
|
for _, mux := range s.dnsMuxMap {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|||||||
@@ -190,75 +190,50 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
|
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
qname := strings.ToLower(question.Name)
|
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
||||||
|
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||||
|
|
||||||
logger.Tracef("question: domain=%s type=%s class=%s",
|
domain := strings.ToLower(question.Name)
|
||||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
network := resutil.NetworkForQtype(question.Qtype)
|
network := resutil.NetworkForQtype(question.Qtype)
|
||||||
if network == "" {
|
if network == "" {
|
||||||
resp.Rcode = dns.RcodeNotImplemented
|
resp.Rcode = dns.RcodeNotImplemented
|
||||||
f.writeResponse(logger, w, resp, qname, startTime)
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
return
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
||||||
|
// query doesn't match any configured domain
|
||||||
if mostSpecificResId == "" {
|
if mostSpecificResId == "" {
|
||||||
resp.Rcode = dns.RcodeRefused
|
resp.Rcode = dns.RcodeRefused
|
||||||
f.writeResponse(logger, w, resp, qname, startTime)
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
return
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
||||||
f.cache.set(qname, question.Qtype, result.IPs)
|
f.cache.set(domain, question.Qtype, result.IPs)
|
||||||
|
|
||||||
f.writeResponse(logger, w, resp, qname, startTime)
|
return resp
|
||||||
}
|
|
||||||
|
|
||||||
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
|
||||||
|
|
||||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
|
||||||
type udpResponseWriter struct {
|
|
||||||
dns.ResponseWriter
|
|
||||||
query *dns.Msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
|
||||||
opt := u.query.IsEdns0()
|
|
||||||
maxSize := dns.MinMsgSize
|
|
||||||
if opt != nil {
|
|
||||||
maxSize = int(opt.UDPSize())
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Len() > maxSize {
|
|
||||||
resp.Truncate(maxSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
return u.ResponseWriter.WriteMsg(resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -268,7 +243,30 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
resp := f.handleDNSQuery(logger, w, query)
|
||||||
|
if resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
opt := query.IsEdns0()
|
||||||
|
maxSize := dns.MinMsgSize
|
||||||
|
if opt != nil {
|
||||||
|
// client advertised a larger EDNS0 buffer
|
||||||
|
maxSize = int(opt.UDPSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
// if our response is too big, truncate and set the TC bit
|
||||||
|
if resp.Len() > maxSize {
|
||||||
|
resp.Truncate(maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
|
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -278,7 +276,18 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
f.handleDNSQuery(logger, w, query, startTime)
|
resp := f.handleDNSQuery(logger, w, query)
|
||||||
|
if resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
|
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
@@ -325,7 +334,6 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
resp *dns.Msg,
|
resp *dns.Msg,
|
||||||
domain string,
|
domain string,
|
||||||
result resutil.LookupResult,
|
result resutil.LookupResult,
|
||||||
startTime time.Time,
|
|
||||||
) {
|
) {
|
||||||
qType := question.Qtype
|
qType := question.Qtype
|
||||||
qTypeName := dns.TypeToString[qType]
|
qTypeName := dns.TypeToString[qType]
|
||||||
@@ -335,7 +343,9 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// NotFound: cache negative result and respond
|
// NotFound: cache negative result and respond
|
||||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||||
f.cache.set(domain, question.Qtype, nil)
|
f.cache.set(domain, question.Qtype, nil)
|
||||||
f.writeResponse(logger, w, resp, domain, startTime)
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -345,7 +355,9 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||||
resp.Rcode = dns.RcodeSuccess
|
resp.Rcode = dns.RcodeSuccess
|
||||||
f.writeResponse(logger, w, resp, domain, startTime)
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -353,7 +365,9 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||||
resp.Rcode = verifyResult.Rcode
|
resp.Rcode = verifyResult.Rcode
|
||||||
f.writeResponse(logger, w, resp, domain, startTime)
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -361,12 +375,15 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// No cache or verification failed. Log with or without the server field for more context.
|
// No cache or verification failed. Log with or without the server field for more context.
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||||
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||||
} else {
|
} else {
|
||||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.writeResponse(logger, w, resp, domain, startTime)
|
// Write final failure response.
|
||||||
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||||
|
|||||||
@@ -318,9 +318,8 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
resp := mockWriter.GetLastResponse()
|
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||||
@@ -330,9 +329,10 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
mockFirewall.AssertExpectations(t)
|
mockFirewall.AssertExpectations(t)
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
} else {
|
} else {
|
||||||
require.NotNil(t, resp, "Expected response")
|
if resp != nil {
|
||||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||||
"Unauthorized domain should not return successful answers")
|
"Unauthorized domain should not return successful answers")
|
||||||
|
}
|
||||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||||
}
|
}
|
||||||
@@ -466,16 +466,14 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
|||||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
||||||
|
|
||||||
// Verify response
|
// Verify response
|
||||||
resp := mockWriter.GetLastResponse()
|
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.NotEmpty(t, resp.Answer)
|
require.NotEmpty(t, resp.Answer)
|
||||||
} else {
|
} else if resp != nil {
|
||||||
require.NotNil(t, resp, "Expected response")
|
|
||||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||||
"Unauthorized domain should be refused or have no answers")
|
"Unauthorized domain should be refused or have no answers")
|
||||||
}
|
}
|
||||||
@@ -530,10 +528,9 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|||||||
query.SetQuestion("example.com.", dns.TypeA)
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
// Verify response contains all IPs
|
// Verify response contains all IPs
|
||||||
resp := mockWriter.GetLastResponse()
|
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||||
@@ -608,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
// Check the response written to the writer
|
// Check the response written to the writer
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
@@ -678,8 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||||
resp1 := w1.GetLastResponse()
|
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -687,13 +683,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
// Second query: serve from cache after upstream failure
|
// Second query: serve from cache after upstream failure
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
w2 := &test.MockResponseWriter{}
|
var writtenResp *dns.Msg
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||||
|
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||||
|
|
||||||
resp2 := w2.GetLastResponse()
|
require.NotNil(t, writtenResp, "expected response to be written")
|
||||||
require.NotNil(t, resp2, "expected response to be written")
|
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
require.Len(t, writtenResp.Answer, 1)
|
||||||
require.Len(t, resp2.Answer, 1)
|
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -719,8 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||||
resp1 := w1.GetLastResponse()
|
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -732,13 +727,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
|
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||||
w2 := &test.MockResponseWriter{}
|
var writtenResp *dns.Msg
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||||
|
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||||
|
|
||||||
resp2 := w2.GetLastResponse()
|
require.NotNil(t, writtenResp)
|
||||||
require.NotNil(t, resp2)
|
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
require.Len(t, writtenResp.Answer, 1)
|
||||||
require.Len(t, resp2.Answer, 1)
|
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -789,9 +784,8 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
resp := mockWriter.GetLastResponse()
|
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
|
||||||
@@ -903,15 +897,26 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
var writtenResp *dns.Msg
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
resp := mockWriter.GetLastResponse()
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
require.NotNil(t, resp, "Expected response to be written")
|
|
||||||
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
|
// If a response was returned, it means it should be written (happens in wrapper functions)
|
||||||
|
if resp != nil && writtenResp == nil {
|
||||||
|
writtenResp = resp
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
|
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||||
|
|
||||||
if tt.expectNoAnswer {
|
if tt.expectNoAnswer {
|
||||||
assert.Empty(t, resp.Answer, "Response should have no answer records")
|
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
||||||
}
|
}
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
@@ -926,8 +931,15 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
// Don't set any question
|
// Don't set any question
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
writeCalled := false
|
||||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writeCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
|
assert.Nil(t, resp, "Should return nil for empty query")
|
||||||
|
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
@@ -45,6 +44,7 @@ import (
|
|||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
@@ -141,6 +141,11 @@ type EngineConfig struct {
|
|||||||
ProfileConfig *profilemanager.Config
|
ProfileConfig *profilemanager.Config
|
||||||
|
|
||||||
LogPath string
|
LogPath string
|
||||||
|
|
||||||
|
// ProxyConfig contains system proxy settings for macOS
|
||||||
|
ProxyEnabled bool
|
||||||
|
ProxyHost string
|
||||||
|
ProxyPort int
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@@ -224,6 +229,9 @@ type Engine struct {
|
|||||||
|
|
||||||
jobExecutor *jobexec.Executor
|
jobExecutor *jobexec.Executor
|
||||||
jobExecutorWG sync.WaitGroup
|
jobExecutorWG sync.WaitGroup
|
||||||
|
|
||||||
|
// proxyManager manages system-wide browser proxy settings on macOS
|
||||||
|
proxyManager *proxy.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -314,6 +322,12 @@ func (e *Engine) Stop() error {
|
|||||||
e.updateManager.Stop()
|
e.updateManager.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.proxyManager != nil {
|
||||||
|
if err := e.proxyManager.DisableWebProxy(); err != nil {
|
||||||
|
log.Warnf("failed to disable system proxy: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Info("cleaning up status recorder states")
|
log.Info("cleaning up status recorder states")
|
||||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
@@ -449,6 +463,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
}
|
}
|
||||||
e.stateManager.Start()
|
e.stateManager.Start()
|
||||||
|
|
||||||
|
// Initialize proxy manager and register state for cleanup
|
||||||
|
proxy.RegisterState(e.stateManager)
|
||||||
|
e.proxyManager = proxy.NewManager(e.stateManager)
|
||||||
|
|
||||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
@@ -544,12 +562,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
wgIfaceName := e.wgInterface.Name()
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
|
|
||||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
e.triggerClientRestart()
|
e.triggerClientRestart()
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -830,10 +847,6 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
started := time.Now()
|
|
||||||
defer func() {
|
|
||||||
log.Infof("sync finished in %s", time.Since(started))
|
|
||||||
}()
|
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@@ -1023,7 +1036,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
state := e.statusRecorder.GetLocalPeerState()
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
state.IP = e.wgInterface.Address().String()
|
state.IP = e.wgInterface.Address().String()
|
||||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
||||||
state.FQDN = conf.GetFqdn()
|
state.FQDN = conf.GetFqdn()
|
||||||
|
|
||||||
e.statusRecorder.UpdateLocalPeerState(state)
|
e.statusRecorder.UpdateLocalPeerState(state)
|
||||||
@@ -1318,6 +1331,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||||
e.dnsServer.ProbeAvailability()
|
e.dnsServer.ProbeAvailability()
|
||||||
|
|
||||||
|
// Update system proxy state based on routes after network map is fully applied
|
||||||
|
e.updateSystemProxy(clientRoutes)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1924,7 +1940,7 @@ func (e *Engine) triggerClientRestart() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) startNetworkMonitor() {
|
func (e *Engine) startNetworkMonitor() {
|
||||||
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
|
if !e.config.NetworkMonitor {
|
||||||
log.Infof("Network monitor is disabled, not starting")
|
log.Infof("Network monitor is disabled, not starting")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -2309,6 +2325,26 @@ func createFile(path string) error {
|
|||||||
return file.Close()
|
return file.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateSystemProxy triggers a proxy enable/disable cycle after the network map is updated.
|
||||||
|
func (e *Engine) updateSystemProxy(clientRoutes route.HAMap) {
|
||||||
|
if runtime.GOOS != "darwin" || e.proxyManager == nil {
|
||||||
|
log.Errorf("not updating proxy")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.proxyManager.EnableWebProxy(e.config.ProxyHost, e.config.ProxyPort); err != nil {
|
||||||
|
log.Errorf("enable system proxy: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error("system proxy enabled after network map update")
|
||||||
|
|
||||||
|
if err := e.proxyManager.DisableWebProxy(); err != nil {
|
||||||
|
log.Errorf("disable system proxy: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error("system proxy disabled after network map update")
|
||||||
|
}
|
||||||
|
|
||||||
func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
||||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
@@ -95,10 +94,6 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
|
|
||||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||||
if netstack.IsEnabled() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||||
if len(peerInfo) == 0 {
|
if len(peerInfo) == 0 {
|
||||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||||
@@ -221,10 +216,6 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
|||||||
|
|
||||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||||
func (e *Engine) cleanupSSHConfig() {
|
func (e *Engine) cleanupSSHConfig() {
|
||||||
if netstack.IsEnabled() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
configMgr := sshconfig.New()
|
configMgr := sshconfig.New()
|
||||||
|
|
||||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
@@ -75,13 +74,12 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
|||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BindListener is used on Windows, JS, and netstack platforms:
|
// BindListener is only used on Windows and JS platforms:
|
||||||
// - JS: Cannot listen to UDP sockets
|
// - JS: Cannot listen to UDP sockets
|
||||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||||
// gateway points to, preventing them from reaching the loopback interface.
|
// gateway points to, preventing them from reaching the loopback interface.
|
||||||
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
// BindListener bypasses this by passing data directly through the bind.
|
||||||
// BindListener bypasses these issues by passing data directly through the bind.
|
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
||||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
|
||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -37,6 +38,11 @@ func New() *NetworkMonitor {
|
|||||||
|
|
||||||
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
||||||
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
log.Debugf("Network monitor: skipping in netstack mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
nw.mu.Lock()
|
nw.mu.Lock()
|
||||||
if nw.cancel != nil {
|
if nw.cancel != nil {
|
||||||
nw.mu.Unlock()
|
nw.mu.Unlock()
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package ice
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -33,6 +32,24 @@ type ThreadSafeAgent struct {
|
|||||||
once sync.Once
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *ThreadSafeAgent) Close() error {
|
||||||
|
var err error
|
||||||
|
a.once.Do(func() {
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- a.Agent.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-done:
|
||||||
|
case <-time.After(iceAgentCloseTimeout):
|
||||||
|
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
@@ -76,41 +93,9 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if agent == nil {
|
|
||||||
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ThreadSafeAgent{Agent: agent}, nil
|
return &ThreadSafeAgent{Agent: agent}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ThreadSafeAgent) Close() error {
|
|
||||||
var err error
|
|
||||||
a.once.Do(func() {
|
|
||||||
// Defensive check to prevent nil pointer dereference
|
|
||||||
// This can happen during sleep/wake transitions or memory corruption scenarios
|
|
||||||
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
|
|
||||||
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
|
|
||||||
agent := a.Agent
|
|
||||||
if agent == nil {
|
|
||||||
log.Warnf("ICE agent is nil during close, skipping")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
done <- agent.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-done:
|
|
||||||
case <-time.After(iceAgentCloseTimeout):
|
|
||||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func GenerateICECredentials() (string, string, error) {
|
func GenerateICECredentials() (string, string, error) {
|
||||||
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -107,10 +107,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
}
|
}
|
||||||
w.log.Debugf("agent already exists, recreate the connection")
|
w.log.Debugf("agent already exists, recreate the connection")
|
||||||
w.agentDialerCancel()
|
w.agentDialerCancel()
|
||||||
if w.agent != nil {
|
if err := w.agent.Close(); err != nil {
|
||||||
if err := w.agent.Close(); err != nil {
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, err := NewICESessionID()
|
sessionID, err := NewICESessionID()
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.AdminURL == nil {
|
if config.AdminURL == nil {
|
||||||
log.Infof("using default Admin URL %s", DefaultAdminURL)
|
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
||||||
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|||||||
262
client/internal/proxy/manager_darwin.go
Normal file
262
client/internal/proxy/manager_darwin.go
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const networksetupPath = "/usr/sbin/networksetup"
|
||||||
|
|
||||||
|
// Manager handles system-wide proxy configuration on macOS.
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
stateManager *statemanager.Manager
|
||||||
|
modifiedServices []string
|
||||||
|
enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new proxy manager.
|
||||||
|
func NewManager(stateManager *statemanager.Manager) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
stateManager: stateManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveNetworkServices returns the list of active network services.
|
||||||
|
func GetActiveNetworkServices() ([]string, error) {
|
||||||
|
cmd := exec.Command(networksetupPath, "-listallnetworkservices")
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list network services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(out), "\n")
|
||||||
|
var services []string
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" || strings.HasPrefix(line, "*") || strings.Contains(line, "asterisk") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
services = append(services, line)
|
||||||
|
}
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableWebProxy enables web proxy for all active network services.
|
||||||
|
func (m *Manager) EnableWebProxy(host string, port int) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.enabled {
|
||||||
|
log.Debug("web proxy already enabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
services, err := GetActiveNetworkServices()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var modifiedServices []string
|
||||||
|
for _, service := range services {
|
||||||
|
if err := m.enableProxyForService(service, host, port); err != nil {
|
||||||
|
log.Warnf("enable proxy for %s: %v", service, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modifiedServices = append(modifiedServices, service)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modifiedServices = modifiedServices
|
||||||
|
m.enabled = true
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
log.Infof("enabled web proxy on %d services -> %s:%d", len(modifiedServices), host, port)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) enableProxyForService(service, host string, port int) error {
|
||||||
|
portStr := fmt.Sprintf("%d", port)
|
||||||
|
|
||||||
|
// Set web proxy (HTTP)
|
||||||
|
cmd := exec.Command(networksetupPath, "-setwebproxy", service, host, portStr)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("set web proxy: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable web proxy
|
||||||
|
cmd = exec.Command(networksetupPath, "-setwebproxystate", service, "on")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("enable web proxy state: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set secure web proxy (HTTPS)
|
||||||
|
cmd = exec.Command(networksetupPath, "-setsecurewebproxy", service, host, portStr)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("set secure web proxy: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable secure web proxy
|
||||||
|
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "on")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("enable secure web proxy state: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("enabled proxy for service %s", service)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableWebProxy disables web proxy for all modified network services.
|
||||||
|
func (m *Manager) DisableWebProxy() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if !m.enabled {
|
||||||
|
log.Debug("web proxy already disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
services := m.modifiedServices
|
||||||
|
if len(services) == 0 {
|
||||||
|
services, _ = GetActiveNetworkServices()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
if err := m.disableProxyForService(service); err != nil {
|
||||||
|
log.Warnf("disable proxy for %s: %v", service, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modifiedServices = nil
|
||||||
|
m.enabled = false
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
log.Info("disabled web proxy")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) disableProxyForService(service string) error {
|
||||||
|
// Disable web proxy (HTTP)
|
||||||
|
cmd := exec.Command(networksetupPath, "-setwebproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("disable web proxy: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable secure web proxy (HTTPS)
|
||||||
|
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("disable secure web proxy: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("disabled proxy for service %s", service)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAutoproxyURL sets the automatic proxy configuration URL (PAC file).
|
||||||
|
func (m *Manager) SetAutoproxyURL(pacURL string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
services, err := GetActiveNetworkServices()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var modifiedServices []string
|
||||||
|
for _, service := range services {
|
||||||
|
cmd := exec.Command(networksetupPath, "-setautoproxyurl", service, pacURL)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("set autoproxy for %s: %v, output: %s", service, err, out)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = exec.Command(networksetupPath, "-setautoproxystate", service, "on")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("enable autoproxy for %s: %v, output: %s", service, err, out)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
modifiedServices = append(modifiedServices, service)
|
||||||
|
log.Debugf("set autoproxy URL for %s -> %s", service, pacURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modifiedServices = modifiedServices
|
||||||
|
m.enabled = true
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableAutoproxy disables automatic proxy configuration.
|
||||||
|
func (m *Manager) DisableAutoproxy() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
services := m.modifiedServices
|
||||||
|
if len(services) == 0 {
|
||||||
|
services, _ = GetActiveNetworkServices()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
cmd := exec.Command(networksetupPath, "-setautoproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("disable autoproxy for %s: %v, output: %s", service, err, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modifiedServices = nil
|
||||||
|
m.enabled = false
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEnabled returns whether the proxy is currently enabled.
|
||||||
|
func (m *Manager) IsEnabled() bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore restores proxy settings from a previous state.
|
||||||
|
func (m *Manager) Restore(services []string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
if err := m.disableProxyForService(service); err != nil {
|
||||||
|
log.Warnf("restore proxy for %s: %v", service, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modifiedServices = nil
|
||||||
|
m.enabled = false
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateState() {
|
||||||
|
if m.stateManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.enabled && len(m.modifiedServices) > 0 {
|
||||||
|
state := &ShutdownState{
|
||||||
|
ModifiedServices: m.modifiedServices,
|
||||||
|
}
|
||||||
|
if err := m.stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("update proxy state: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := m.stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
|
log.Debugf("delete proxy state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
45
client/internal/proxy/manager_other.go
Normal file
45
client/internal/proxy/manager_other.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
//go:build !darwin || ios
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager is a no-op proxy manager for non-macOS platforms.
|
||||||
|
type Manager struct{}
|
||||||
|
|
||||||
|
// NewManager creates a new proxy manager (no-op on non-macOS).
|
||||||
|
func NewManager(_ *statemanager.Manager) *Manager {
|
||||||
|
return &Manager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableWebProxy is a no-op on non-macOS platforms.
|
||||||
|
func (m *Manager) EnableWebProxy(host string, port int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableWebProxy is a no-op on non-macOS platforms.
|
||||||
|
func (m *Manager) DisableWebProxy() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAutoproxyURL is a no-op on non-macOS platforms.
|
||||||
|
func (m *Manager) SetAutoproxyURL(pacURL string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableAutoproxy is a no-op on non-macOS platforms.
|
||||||
|
func (m *Manager) DisableAutoproxy() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEnabled always returns false on non-macOS platforms.
|
||||||
|
func (m *Manager) IsEnabled() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore is a no-op on non-macOS platforms.
|
||||||
|
func (m *Manager) Restore(services []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
88
client/internal/proxy/manager_test.go
Normal file
88
client/internal/proxy/manager_test.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetActiveNetworkServices(t *testing.T) {
|
||||||
|
services, err := GetActiveNetworkServices()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, services, "should have at least one network service")
|
||||||
|
|
||||||
|
// Check that services don't contain invalid entries
|
||||||
|
for _, service := range services {
|
||||||
|
assert.NotEmpty(t, service)
|
||||||
|
assert.NotContains(t, service, "*")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_EnableDisableWebProxy(t *testing.T) {
|
||||||
|
// Skip this test in CI as it requires admin privileges
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping proxy test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewManager(nil)
|
||||||
|
assert.NotNil(t, m)
|
||||||
|
assert.False(t, m.IsEnabled())
|
||||||
|
|
||||||
|
// This test would require admin privileges to actually enable the proxy
|
||||||
|
// So we just test the basic state management
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShutdownState_Name(t *testing.T) {
|
||||||
|
state := &ShutdownState{}
|
||||||
|
assert.Equal(t, "proxy_state", state.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShutdownState_Cleanup_EmptyServices(t *testing.T) {
|
||||||
|
state := &ShutdownState{
|
||||||
|
ModifiedServices: []string{},
|
||||||
|
}
|
||||||
|
err := state.Cleanup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
s string
|
||||||
|
substr string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"Enabled: Yes", "Enabled: Yes", true},
|
||||||
|
{"Enabled: No", "Enabled: Yes", false},
|
||||||
|
{"Server: 127.0.0.1\nEnabled: Yes\nPort: 8080", "Enabled: Yes", true},
|
||||||
|
{"", "Enabled: Yes", false},
|
||||||
|
{"Enabled: Yes", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.s+"_"+tt.substr, func(t *testing.T) {
|
||||||
|
got := contains(tt.s, tt.substr)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProxyEnabled(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
output string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"Enabled: Yes\nServer: 127.0.0.1\nPort: 8080", true},
|
||||||
|
{"Enabled: No\nServer: \nPort: 0", false},
|
||||||
|
{"Server: 127.0.0.1\nEnabled: Yes\nPort: 8080", true},
|
||||||
|
{"", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.output, func(t *testing.T) {
|
||||||
|
got := isProxyEnabled(tt.output)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
105
client/internal/proxy/state_darwin.go
Normal file
105
client/internal/proxy/state_darwin.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShutdownState stores proxy state for cleanup on unclean shutdown.
|
||||||
|
type ShutdownState struct {
|
||||||
|
ModifiedServices []string `json:"modified_services"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the state name for persistence.
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "proxy_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup restores proxy settings after an unclean shutdown.
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
if len(s.ModifiedServices) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("cleaning up proxy state for %d services", len(s.ModifiedServices))
|
||||||
|
|
||||||
|
for _, service := range s.ModifiedServices {
|
||||||
|
// Disable web proxy (HTTP)
|
||||||
|
cmd := exec.Command(networksetupPath, "-setwebproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("cleanup web proxy for %s: %v, output: %s", service, err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable secure web proxy (HTTPS)
|
||||||
|
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("cleanup secure web proxy for %s: %v, output: %s", service, err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable autoproxy
|
||||||
|
cmd = exec.Command(networksetupPath, "-setautoproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("cleanup autoproxy for %s: %v, output: %s", service, err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("cleaned up proxy for service %s", service)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterState registers the proxy state with the state manager.
|
||||||
|
func RegisterState(stateManager *statemanager.Manager) {
|
||||||
|
if stateManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxyState returns the current proxy state from the command line.
|
||||||
|
func GetProxyState(service string) (webProxy, secureProxy, autoProxy bool, err error) {
|
||||||
|
// Check web proxy state
|
||||||
|
cmd := exec.Command(networksetupPath, "-getwebproxy", service)
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return false, false, false, fmt.Errorf("get web proxy: %w", err)
|
||||||
|
}
|
||||||
|
webProxy = isProxyEnabled(string(out))
|
||||||
|
|
||||||
|
// Check secure web proxy state
|
||||||
|
cmd = exec.Command(networksetupPath, "-getsecurewebproxy", service)
|
||||||
|
out, err = cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return false, false, false, fmt.Errorf("get secure web proxy: %w", err)
|
||||||
|
}
|
||||||
|
secureProxy = isProxyEnabled(string(out))
|
||||||
|
|
||||||
|
// Check autoproxy state
|
||||||
|
cmd = exec.Command(networksetupPath, "-getautoproxyurl", service)
|
||||||
|
out, err = cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return false, false, false, fmt.Errorf("get autoproxy: %w", err)
|
||||||
|
}
|
||||||
|
autoProxy = isProxyEnabled(string(out))
|
||||||
|
|
||||||
|
return webProxy, secureProxy, autoProxy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProxyEnabled(output string) bool {
|
||||||
|
return !contains(output, "Enabled: No") && contains(output, "Enabled: Yes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
24
client/internal/proxy/state_other.go
Normal file
24
client/internal/proxy/state_other.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
//go:build !darwin || ios
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShutdownState is a no-op state for non-macOS platforms.
|
||||||
|
type ShutdownState struct{}
|
||||||
|
|
||||||
|
// Name returns the state name.
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "proxy_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup is a no-op on non-macOS platforms.
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterState is a no-op on non-macOS platforms.
|
||||||
|
func RegisterState(stateManager *statemanager.Manager) {
|
||||||
|
}
|
||||||
@@ -173,21 +173,12 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
var once sync.Once
|
|
||||||
var wgIface *net.Interface
|
|
||||||
toInterface := func() *net.Interface {
|
|
||||||
once.Do(func() {
|
|
||||||
wgIface = m.wgInterface.ToInterface()
|
|
||||||
})
|
|
||||||
return wgIface
|
|
||||||
}
|
|
||||||
|
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
|
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, _ struct{}) error {
|
func(prefix netip.Prefix, _ struct{}) error {
|
||||||
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
|
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,17 +4,16 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&unix.RTF_UP == 0 {
|
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
|
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,51 +24,42 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&unix.RTF_UP != 0 {
|
if flags&syscall.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_GATEWAY != 0 {
|
if flags&syscall.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_HOST != 0 {
|
if flags&syscall.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_REJECT != 0 {
|
if flags&syscall.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_DYNAMIC != 0 {
|
if flags&syscall.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_MODIFIED != 0 {
|
if flags&syscall.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_STATIC != 0 {
|
if flags&syscall.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_LLINFO != 0 {
|
if flags&syscall.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_LOCAL != 0 {
|
if flags&syscall.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_BLACKHOLE != 0 {
|
if flags&syscall.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_CLONING != 0 {
|
if flags&syscall.RTF_CLONING != 0 {
|
||||||
flagStrs = append(flagStrs, "C")
|
flagStrs = append(flagStrs, "C")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_WASCLONED != 0 {
|
if flags&syscall.RTF_WASCLONED != 0 {
|
||||||
flagStrs = append(flagStrs, "W")
|
flagStrs = append(flagStrs, "W")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_PROTO1 != 0 {
|
|
||||||
flagStrs = append(flagStrs, "1")
|
|
||||||
}
|
|
||||||
if flags&unix.RTF_PROTO2 != 0 {
|
|
||||||
flagStrs = append(flagStrs, "2")
|
|
||||||
}
|
|
||||||
if flags&unix.RTF_PROTO3 != 0 {
|
|
||||||
flagStrs = append(flagStrs, "3")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
@@ -4,18 +4,17 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&unix.RTF_UP == 0 {
|
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
|
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
|
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,46 +25,37 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&unix.RTF_UP != 0 {
|
if flags&syscall.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_GATEWAY != 0 {
|
if flags&syscall.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_HOST != 0 {
|
if flags&syscall.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_REJECT != 0 {
|
if flags&syscall.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_DYNAMIC != 0 {
|
if flags&syscall.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_MODIFIED != 0 {
|
if flags&syscall.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_STATIC != 0 {
|
if flags&syscall.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_LLINFO != 0 {
|
if flags&syscall.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_LOCAL != 0 {
|
if flags&syscall.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&unix.RTF_BLACKHOLE != 0 {
|
if flags&syscall.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
if flags&unix.RTF_PROTO1 != 0 {
|
|
||||||
flagStrs = append(flagStrs, "1")
|
|
||||||
}
|
|
||||||
if flags&unix.RTF_PROTO2 != 0 {
|
|
||||||
flagStrs = append(flagStrs, "2")
|
|
||||||
}
|
|
||||||
if flags&unix.RTF_PROTO3 != 0 {
|
|
||||||
flagStrs = append(flagStrs, "3")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -42,7 +42,6 @@ require (
|
|||||||
github.com/cilium/ebpf v0.15.0
|
github.com/cilium/ebpf v0.15.0
|
||||||
github.com/coder/websocket v1.8.13
|
github.com/coder/websocket v1.8.13
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
github.com/coreos/go-oidc/v3 v3.14.1
|
|
||||||
github.com/creack/pty v1.1.24
|
github.com/creack/pty v1.1.24
|
||||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||||
github.com/dexidp/dex/api/v2 v2.4.0
|
github.com/dexidp/dex/api/v2 v2.4.0
|
||||||
@@ -69,7 +68,7 @@ require (
|
|||||||
github.com/mdlayher/socket v0.5.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/oapi-codegen/runtime v1.1.2
|
github.com/oapi-codegen/runtime v1.1.2
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
@@ -168,6 +167,7 @@ require (
|
|||||||
github.com/containerd/containerd v1.7.29 // indirect
|
github.com/containerd/containerd v1.7.29 // indirect
|
||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
|
github.com/coreos/go-oidc/v3 v3.14.1 // indirect
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -406,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
|
|||||||
@@ -327,60 +327,6 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasNonLocalConnectors checks if there are any connectors other than the local connector.
|
|
||||||
func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) {
|
|
||||||
connectors, err := p.storage.ListConnectors(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("failed to list connectors: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors))
|
|
||||||
for _, conn := range connectors {
|
|
||||||
p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name)
|
|
||||||
if conn.ID != "local" || conn.Type != "local" {
|
|
||||||
p.logger.Info("found non-local connector", "id", conn.ID)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
p.logger.Info("no non-local connectors found")
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DisableLocalAuth removes the local (password) connector.
|
|
||||||
// Returns an error if no other connectors are configured.
|
|
||||||
func (p *Provider) DisableLocalAuth(ctx context.Context) error {
|
|
||||||
hasOthers, err := p.HasNonLocalConnectors(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !hasOthers {
|
|
||||||
return fmt.Errorf("cannot disable local authentication: no other identity providers configured")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if local connector exists
|
|
||||||
_, err = p.storage.GetConnector(ctx, "local")
|
|
||||||
if errors.Is(err, storage.ErrNotFound) {
|
|
||||||
// Already disabled
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to check local connector: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete the local connector
|
|
||||||
if err := p.storage.DeleteConnector(ctx, "local"); err != nil {
|
|
||||||
return fmt.Errorf("failed to delete local connector: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.logger.Info("local authentication disabled")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnableLocalAuth creates the local (password) connector if it doesn't exist.
|
|
||||||
func (p *Provider) EnableLocalAuth(ctx context.Context) error {
|
|
||||||
return ensureLocalConnector(ctx, p.storage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensureStaticConnectors creates or updates static connectors in storage
|
// ensureStaticConnectors creates or updates static connectors in storage
|
||||||
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
||||||
for _, conn := range connectors {
|
for _, conn := range connectors {
|
||||||
|
|||||||
@@ -91,17 +91,16 @@ read_reverse_proxy_type() {
|
|||||||
echo " [3] Nginx Proxy Manager (generates config + instructions)" > /dev/stderr
|
echo " [3] Nginx Proxy Manager (generates config + instructions)" > /dev/stderr
|
||||||
echo " [4] External Caddy (generates Caddyfile snippet)" > /dev/stderr
|
echo " [4] External Caddy (generates Caddyfile snippet)" > /dev/stderr
|
||||||
echo " [5] Other/Manual (displays setup documentation)" > /dev/stderr
|
echo " [5] Other/Manual (displays setup documentation)" > /dev/stderr
|
||||||
echo " [6] Traefik TCP Proxy (single port 443 + STUN)" > /dev/stderr
|
|
||||||
echo "" > /dev/stderr
|
echo "" > /dev/stderr
|
||||||
echo -n "Enter choice [0-6] (default: 0): " > /dev/stderr
|
echo -n "Enter choice [0-5] (default: 0): " > /dev/stderr
|
||||||
read -r CHOICE < /dev/tty
|
read -r CHOICE < /dev/tty
|
||||||
|
|
||||||
if [[ -z "$CHOICE" ]]; then
|
if [[ -z "$CHOICE" ]]; then
|
||||||
CHOICE="0"
|
CHOICE="0"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ ! "$CHOICE" =~ ^[0-6]$ ]]; then
|
if [[ ! "$CHOICE" =~ ^[0-5]$ ]]; then
|
||||||
echo "Invalid choice. Please enter a number between 0 and 6." > /dev/stderr
|
echo "Invalid choice. Please enter a number between 0 and 5." > /dev/stderr
|
||||||
read_reverse_proxy_type
|
read_reverse_proxy_type
|
||||||
return
|
return
|
||||||
fi
|
fi
|
||||||
@@ -141,35 +140,6 @@ read_traefik_certresolver() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
read_traefik_tcp_acme_email() {
|
|
||||||
echo "" > /dev/stderr
|
|
||||||
echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr
|
|
||||||
echo -n "Email address: " > /dev/stderr
|
|
||||||
read -r EMAIL < /dev/tty
|
|
||||||
if [[ -z "$EMAIL" ]]; then
|
|
||||||
echo "Email is required for Let's Encrypt." > /dev/stderr
|
|
||||||
read_traefik_tcp_acme_email
|
|
||||||
return
|
|
||||||
fi
|
|
||||||
echo "$EMAIL"
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
read_enable_proxy() {
|
|
||||||
echo "" > /dev/stderr
|
|
||||||
echo "Do you want to enable the NetBird Proxy service?" > /dev/stderr
|
|
||||||
echo "The proxy exposes internal NetBird network resources to the internet." > /dev/stderr
|
|
||||||
echo -n "Enable proxy? [y/N]: " > /dev/stderr
|
|
||||||
read -r CHOICE < /dev/tty
|
|
||||||
|
|
||||||
if [[ "$CHOICE" =~ ^[Yy]$ ]]; then
|
|
||||||
echo "true"
|
|
||||||
else
|
|
||||||
echo "false"
|
|
||||||
fi
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
read_port_binding_preference() {
|
read_port_binding_preference() {
|
||||||
echo "" > /dev/stderr
|
echo "" > /dev/stderr
|
||||||
echo "Should container ports be bound to localhost only (127.0.0.1)?" > /dev/stderr
|
echo "Should container ports be bound to localhost only (127.0.0.1)?" > /dev/stderr
|
||||||
@@ -236,30 +206,6 @@ wait_management() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
wait_management_traefik() {
|
|
||||||
set +e
|
|
||||||
echo -n "Waiting for Management server to become ready"
|
|
||||||
counter=1
|
|
||||||
while true; do
|
|
||||||
# Check the embedded IdP endpoint through Traefik
|
|
||||||
if curl -sk -f -o /dev/null "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2/.well-known/openid-configuration" 2>/dev/null; then
|
|
||||||
break
|
|
||||||
fi
|
|
||||||
if [[ $counter -eq 60 ]]; then
|
|
||||||
echo ""
|
|
||||||
echo "Taking too long. Checking logs..."
|
|
||||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 traefik
|
|
||||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 management
|
|
||||||
fi
|
|
||||||
echo -n " ."
|
|
||||||
sleep 2
|
|
||||||
counter=$((counter + 1))
|
|
||||||
done
|
|
||||||
echo " done"
|
|
||||||
set -e
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
wait_management_direct() {
|
wait_management_direct() {
|
||||||
set +e
|
set +e
|
||||||
local upstream_host=$(get_upstream_host)
|
local upstream_host=$(get_upstream_host)
|
||||||
@@ -300,12 +246,10 @@ initialize_default_values() {
|
|||||||
|
|
||||||
# Docker images
|
# Docker images
|
||||||
CADDY_IMAGE="caddy"
|
CADDY_IMAGE="caddy"
|
||||||
#DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
||||||
DASHBOARD_IMAGE="netbirdio/dashboard:pr-552"
|
|
||||||
SIGNAL_IMAGE="netbirdio/signal:latest"
|
SIGNAL_IMAGE="netbirdio/signal:latest"
|
||||||
RELAY_IMAGE="netbirdio/relay:latest"
|
RELAY_IMAGE="netbirdio/relay:latest"
|
||||||
MANAGEMENT_IMAGE="netbirdio/management:latest"
|
MANAGEMENT_IMAGE="netbirdio/management:latest"
|
||||||
PROXY_IMAGE=""
|
|
||||||
|
|
||||||
# Reverse proxy configuration
|
# Reverse proxy configuration
|
||||||
REVERSE_PROXY_TYPE="0"
|
REVERSE_PROXY_TYPE="0"
|
||||||
@@ -319,14 +263,6 @@ initialize_default_values() {
|
|||||||
RELAY_HOST_PORT="8084"
|
RELAY_HOST_PORT="8084"
|
||||||
BIND_LOCALHOST_ONLY="true"
|
BIND_LOCALHOST_ONLY="true"
|
||||||
EXTERNAL_PROXY_NETWORK=""
|
EXTERNAL_PROXY_NETWORK=""
|
||||||
|
|
||||||
# Traefik TCP proxy configuration
|
|
||||||
TRAEFIK_IMAGE="traefik:v3.6"
|
|
||||||
TRAEFIK_TCP_ACME_EMAIL=""
|
|
||||||
|
|
||||||
# NetBird Proxy configuration
|
|
||||||
ENABLE_PROXY="false"
|
|
||||||
PROXY_TOKEN=""
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -357,17 +293,8 @@ configure_reverse_proxy() {
|
|||||||
TRAEFIK_CERTRESOLVER=$(read_traefik_certresolver)
|
TRAEFIK_CERTRESOLVER=$(read_traefik_certresolver)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Handle Traefik TCP proxy prompts
|
|
||||||
if [[ "$REVERSE_PROXY_TYPE" == "6" ]]; then
|
|
||||||
TRAEFIK_TCP_ACME_EMAIL=$(read_traefik_tcp_acme_email)
|
|
||||||
|
|
||||||
# Prompt for NetBird Proxy configuration
|
|
||||||
ENABLE_PROXY=$(read_enable_proxy)
|
|
||||||
# Note: PROXY_TOKEN will be auto-generated after Management starts
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Handle port binding for external proxy options (2-5)
|
# Handle port binding for external proxy options (2-5)
|
||||||
if [[ "$REVERSE_PROXY_TYPE" -ge 2 && "$REVERSE_PROXY_TYPE" -le 5 ]]; then
|
if [[ "$REVERSE_PROXY_TYPE" -ge 2 ]]; then
|
||||||
BIND_LOCALHOST_ONLY=$(read_port_binding_preference)
|
BIND_LOCALHOST_ONLY=$(read_port_binding_preference)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@@ -386,7 +313,7 @@ check_existing_installation() {
|
|||||||
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
||||||
echo "You can use the following commands:"
|
echo "You can use the following commands:"
|
||||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
||||||
echo " rm -f docker-compose.yml Caddyfile dashboard.env management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt proxy.env"
|
echo " rm -f docker-compose.yml Caddyfile dashboard.env management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
||||||
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
@@ -420,15 +347,6 @@ generate_configuration_files() {
|
|||||||
5)
|
5)
|
||||||
render_docker_compose_exposed_ports > docker-compose.yml
|
render_docker_compose_exposed_ports > docker-compose.yml
|
||||||
;;
|
;;
|
||||||
6)
|
|
||||||
render_docker_compose_traefik_tcp > docker-compose.yml
|
|
||||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
|
||||||
# Create placeholder proxy.env so docker-compose can validate
|
|
||||||
# This will be overwritten with the actual token after Management starts
|
|
||||||
echo "# Placeholder - will be updated with token after Management starts" > proxy.env
|
|
||||||
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
|
|
||||||
fi
|
|
||||||
;;
|
|
||||||
*)
|
*)
|
||||||
echo "Invalid reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
echo "Invalid reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
||||||
exit 1
|
exit 1
|
||||||
@@ -484,50 +402,6 @@ start_services_and_show_instructions() {
|
|||||||
echo ""
|
echo ""
|
||||||
echo "NetBird containers are running. Configure NPM as shown above, then access:"
|
echo "NetBird containers are running. Configure NPM as shown above, then access:"
|
||||||
echo " $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
echo " $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||||
elif [[ "$REVERSE_PROXY_TYPE" == "6" ]]; then
|
|
||||||
# Traefik TCP Proxy - two-phase startup if proxy is enabled
|
|
||||||
echo -e "$MSG_STARTING_SERVICES"
|
|
||||||
|
|
||||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
|
||||||
# Phase 1: Start core services (without proxy)
|
|
||||||
echo "Starting core services..."
|
|
||||||
$DOCKER_COMPOSE_COMMAND up -d traefik dashboard signal relay management
|
|
||||||
|
|
||||||
sleep 3
|
|
||||||
wait_management_traefik
|
|
||||||
|
|
||||||
# Phase 2: Create proxy token and start proxy
|
|
||||||
echo ""
|
|
||||||
echo "Creating proxy access token..."
|
|
||||||
# Use docker exec with bash to run the token command directly
|
|
||||||
# (bypassing the entrypoint which adds 'management' as first arg)
|
|
||||||
PROXY_TOKEN=$($DOCKER_COMPOSE_COMMAND exec -T management \
|
|
||||||
bash -c '/go/bin/netbird-mgmt token create --name "default-proxy" --config /etc/netbird/management.json' 2>/dev/null | grep "^Token:" | awk '{print $2}')
|
|
||||||
|
|
||||||
if [[ -z "$PROXY_TOKEN" ]]; then
|
|
||||||
echo "ERROR: Failed to create proxy token. Check management logs." > /dev/stderr
|
|
||||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 management
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "Proxy token created successfully."
|
|
||||||
|
|
||||||
# Generate proxy.env with the token
|
|
||||||
render_proxy_env > proxy.env
|
|
||||||
|
|
||||||
# Start proxy service
|
|
||||||
echo "Starting proxy service..."
|
|
||||||
$DOCKER_COMPOSE_COMMAND up -d proxy
|
|
||||||
else
|
|
||||||
# No proxy - start all services at once
|
|
||||||
$DOCKER_COMPOSE_COMMAND up -d
|
|
||||||
|
|
||||||
sleep 3
|
|
||||||
wait_management_traefik
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo -e "$MSG_DONE"
|
|
||||||
print_post_setup_instructions
|
|
||||||
else
|
else
|
||||||
# External proxies (nginx, external Caddy, other) - need manual config first
|
# External proxies (nginx, external Caddy, other) - need manual config first
|
||||||
print_post_setup_instructions
|
print_post_setup_instructions
|
||||||
@@ -673,29 +547,6 @@ EOF
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
render_proxy_env() {
|
|
||||||
cat <<EOF
|
|
||||||
# NetBird Proxy Configuration
|
|
||||||
NB_PROXY_DEBUG_LOGS=false
|
|
||||||
# Use internal Docker network to connect to management (avoids hairpin NAT issues)
|
|
||||||
NB_PROXY_MANAGEMENT_ADDRESS=http://management:80
|
|
||||||
# Allow insecure gRPC connection to management (required for internal Docker network)
|
|
||||||
NB_PROXY_ALLOW_INSECURE=true
|
|
||||||
# Public URL where this proxy is reachable (used for cluster registration)
|
|
||||||
NB_PROXY_DOMAIN=$NETBIRD_DOMAIN
|
|
||||||
NB_PROXY_ADDRESS=:8443
|
|
||||||
NB_PROXY_TOKEN=$PROXY_TOKEN
|
|
||||||
NB_PROXY_CERTIFICATE_DIRECTORY=/certs
|
|
||||||
NB_PROXY_ACME_CERTIFICATES=true
|
|
||||||
NB_PROXY_ACME_CHALLENGE_TYPE=tls-alpn-01
|
|
||||||
NB_PROXY_OIDC_CLIENT_ID=netbird-proxy
|
|
||||||
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
|
||||||
NB_PROXY_OIDC_SCOPES=openid,profile,email
|
|
||||||
NB_PROXY_FORWARDED_PROTO=https
|
|
||||||
EOF
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
render_docker_compose() {
|
render_docker_compose() {
|
||||||
cat <<EOF
|
cat <<EOF
|
||||||
services:
|
services:
|
||||||
@@ -885,18 +736,7 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-rel
|
|||||||
|
|
||||||
# Management (includes embedded IdP)
|
# Management (includes embedded IdP)
|
||||||
management:
|
management:
|
||||||
$(if [[ "$ENABLE_PROXY" == "true" ]]; then
|
|
||||||
cat <<MGMT_BUILD
|
|
||||||
build:
|
|
||||||
context: ..
|
|
||||||
dockerfile: management/Dockerfile.multistage
|
|
||||||
pull_policy: build
|
|
||||||
MGMT_BUILD
|
|
||||||
else
|
|
||||||
cat <<MGMT_IMAGE
|
|
||||||
image: $MANAGEMENT_IMAGE
|
image: $MANAGEMENT_IMAGE
|
||||||
MGMT_IMAGE
|
|
||||||
fi)
|
|
||||||
container_name: netbird-management
|
container_name: netbird-management
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks: [$network_name]
|
networks: [$network_name]
|
||||||
@@ -1275,258 +1115,6 @@ EOF
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
render_docker_compose_traefik_tcp() {
|
|
||||||
# Generate proxy service section if enabled
|
|
||||||
local proxy_service=""
|
|
||||||
local proxy_volumes=""
|
|
||||||
local proxy_tcp_labels=""
|
|
||||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
|
||||||
proxy_service="
|
|
||||||
# NetBird Proxy - exposes internal resources to the internet
|
|
||||||
proxy:
|
|
||||||
build:
|
|
||||||
context: ../
|
|
||||||
dockerfile: proxy/Dockerfile
|
|
||||||
# Always rebuild to pick up code changes during testing
|
|
||||||
pull_policy: build
|
|
||||||
#image: $PROXY_IMAGE
|
|
||||||
container_name: netbird-proxy
|
|
||||||
# Hairpin NAT fix: route domain back to traefik's static IP within Docker
|
|
||||||
extra_hosts:
|
|
||||||
- \"$NETBIRD_DOMAIN:172.30.0.10\"
|
|
||||||
restart: unless-stopped
|
|
||||||
networks: [netbird]
|
|
||||||
depends_on:
|
|
||||||
- signal
|
|
||||||
env_file:
|
|
||||||
- ./proxy.env
|
|
||||||
volumes:
|
|
||||||
- netbird_proxy_certs:/certs
|
|
||||||
labels:
|
|
||||||
# TCP passthrough for any unmatched domain (proxy handles its own TLS)
|
|
||||||
- traefik.enable=true
|
|
||||||
- traefik.tcp.routers.proxy-passthrough.entrypoints=websecure
|
|
||||||
- traefik.tcp.routers.proxy-passthrough.rule=HostSNI(\`*\`)
|
|
||||||
- traefik.tcp.routers.proxy-passthrough.tls.passthrough=true
|
|
||||||
- traefik.tcp.routers.proxy-passthrough.service=proxy-tls
|
|
||||||
- traefik.tcp.routers.proxy-passthrough.priority=1
|
|
||||||
- traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443
|
|
||||||
logging:
|
|
||||||
driver: \"json-file\"
|
|
||||||
options:
|
|
||||||
max-size: \"500m\"
|
|
||||||
max-file: \"2\"
|
|
||||||
"
|
|
||||||
proxy_volumes="
|
|
||||||
netbird_proxy_certs:"
|
|
||||||
fi
|
|
||||||
|
|
||||||
cat <<EOF
|
|
||||||
services:
|
|
||||||
# Traefik - single port 443 entry point with TLS termination
|
|
||||||
traefik:
|
|
||||||
image: $TRAEFIK_IMAGE
|
|
||||||
container_name: netbird-traefik
|
|
||||||
restart: unless-stopped
|
|
||||||
networks:
|
|
||||||
netbird:
|
|
||||||
ipv4_address: 172.30.0.10
|
|
||||||
ports:
|
|
||||||
- '443:443'
|
|
||||||
volumes:
|
|
||||||
- netbird_traefik_data:/data
|
|
||||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
|
||||||
command:
|
|
||||||
# Logging
|
|
||||||
- --log.level=INFO
|
|
||||||
- --accesslog=true
|
|
||||||
# Docker provider
|
|
||||||
- --providers.docker=true
|
|
||||||
- --providers.docker.exposedbydefault=false
|
|
||||||
- --providers.docker.network=netbird
|
|
||||||
# Entrypoints
|
|
||||||
- --entrypoints.websecure.address=:443
|
|
||||||
- --entrypoints.websecure.allowACMEByPass=true
|
|
||||||
# Disable timeouts for long-lived gRPC streams
|
|
||||||
- --entrypoints.websecure.transport.respondingtimeouts.readtimeout=0s
|
|
||||||
- --entrypoints.websecure.transport.respondingtimeouts.writetimeout=0s
|
|
||||||
- --entrypoints.websecure.transport.respondingtimeouts.idletimeout=0s
|
|
||||||
# Let's Encrypt ACME
|
|
||||||
- --certificatesresolvers.letsencrypt.acme.email=$TRAEFIK_TCP_ACME_EMAIL
|
|
||||||
- --certificatesresolvers.letsencrypt.acme.storage=/data/acme.json
|
|
||||||
- --certificatesresolvers.letsencrypt.acme.tlschallenge=true
|
|
||||||
# gRPC transport settings (disable response timeout for long-lived streams)
|
|
||||||
- --serverstransport.forwardingtimeouts.responseheadertimeout=0s
|
|
||||||
- --serverstransport.forwardingtimeouts.idleconntimeout=0s
|
|
||||||
logging:
|
|
||||||
driver: "json-file"
|
|
||||||
options:
|
|
||||||
max-size: "500m"
|
|
||||||
max-file: "2"
|
|
||||||
|
|
||||||
# UI dashboard
|
|
||||||
dashboard:
|
|
||||||
image: $DASHBOARD_IMAGE
|
|
||||||
container_name: netbird-dashboard
|
|
||||||
restart: unless-stopped
|
|
||||||
networks: [netbird]
|
|
||||||
env_file:
|
|
||||||
- ./dashboard.env
|
|
||||||
labels:
|
|
||||||
- traefik.enable=true
|
|
||||||
- traefik.http.routers.netbird-dashboard.entrypoints=websecure
|
|
||||||
- traefik.http.routers.netbird-dashboard.rule=Host(\`$NETBIRD_DOMAIN\`)
|
|
||||||
- traefik.http.routers.netbird-dashboard.tls=true
|
|
||||||
- traefik.http.routers.netbird-dashboard.tls.certresolver=letsencrypt
|
|
||||||
- traefik.http.routers.netbird-dashboard.service=dashboard
|
|
||||||
- traefik.http.routers.netbird-dashboard.priority=1
|
|
||||||
- traefik.http.services.dashboard.loadbalancer.server.port=80
|
|
||||||
logging:
|
|
||||||
driver: "json-file"
|
|
||||||
options:
|
|
||||||
max-size: "500m"
|
|
||||||
max-file: "2"
|
|
||||||
|
|
||||||
# Signal
|
|
||||||
signal:
|
|
||||||
image: $SIGNAL_IMAGE
|
|
||||||
container_name: netbird-signal
|
|
||||||
restart: unless-stopped
|
|
||||||
networks: [netbird]
|
|
||||||
labels:
|
|
||||||
- traefik.enable=true
|
|
||||||
# Signal WebSocket
|
|
||||||
- traefik.http.routers.netbird-signal-ws.entrypoints=websecure
|
|
||||||
- traefik.http.routers.netbird-signal-ws.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/ws-proxy/signal\`)
|
|
||||||
- traefik.http.routers.netbird-signal-ws.tls=true
|
|
||||||
- traefik.http.routers.netbird-signal-ws.tls.certresolver=letsencrypt
|
|
||||||
- traefik.http.routers.netbird-signal-ws.service=signal-ws
|
|
||||||
- traefik.http.routers.netbird-signal-ws.priority=100
|
|
||||||
- traefik.http.services.signal-ws.loadbalancer.server.port=80
|
|
||||||
# Signal gRPC
|
|
||||||
- traefik.http.routers.netbird-signal-grpc.entrypoints=websecure
|
|
||||||
- traefik.http.routers.netbird-signal-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/signalexchange.SignalExchange/\`)
|
|
||||||
- traefik.http.routers.netbird-signal-grpc.tls=true
|
|
||||||
- traefik.http.routers.netbird-signal-grpc.tls.certresolver=letsencrypt
|
|
||||||
- traefik.http.routers.netbird-signal-grpc.service=signal-grpc
|
|
||||||
- traefik.http.routers.netbird-signal-grpc.priority=100
|
|
||||||
- traefik.http.services.signal-grpc.loadbalancer.server.port=10000
|
|
||||||
- traefik.http.services.signal-grpc.loadbalancer.server.scheme=h2c
|
|
||||||
logging:
|
|
||||||
driver: "json-file"
|
|
||||||
options:
|
|
||||||
max-size: "500m"
|
|
||||||
max-file: "2"
|
|
||||||
|
|
||||||
# Relay (includes embedded STUN server)
|
|
||||||
relay:
|
|
||||||
image: $RELAY_IMAGE
|
|
||||||
container_name: netbird-relay
|
|
||||||
restart: unless-stopped
|
|
||||||
networks: [netbird]
|
|
||||||
ports:
|
|
||||||
- '$NETBIRD_STUN_PORT:$NETBIRD_STUN_PORT/udp'
|
|
||||||
env_file:
|
|
||||||
- ./relay.env
|
|
||||||
labels:
|
|
||||||
- traefik.enable=true
|
|
||||||
- traefik.http.routers.netbird-relay.entrypoints=websecure
|
|
||||||
- traefik.http.routers.netbird-relay.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/relay\`)
|
|
||||||
- traefik.http.routers.netbird-relay.tls=true
|
|
||||||
- traefik.http.routers.netbird-relay.tls.certresolver=letsencrypt
|
|
||||||
- traefik.http.routers.netbird-relay.service=relay
|
|
||||||
- traefik.http.routers.netbird-relay.priority=100
|
|
||||||
- traefik.http.services.relay.loadbalancer.server.port=80
|
|
||||||
logging:
|
|
||||||
driver: "json-file"
|
|
||||||
options:
|
|
||||||
max-size: "500m"
|
|
||||||
max-file: "2"
|
|
||||||
|
|
||||||
# Management (includes embedded IdP)
|
|
||||||
management:
|
|
||||||
$(if [[ "$ENABLE_PROXY" == "true" ]]; then
|
|
||||||
cat <<MGMT_BUILD
|
|
||||||
build:
|
|
||||||
context: ..
|
|
||||||
dockerfile: management/Dockerfile.multistage
|
|
||||||
pull_policy: build
|
|
||||||
MGMT_BUILD
|
|
||||||
else
|
|
||||||
cat <<MGMT_IMAGE
|
|
||||||
image: $MANAGEMENT_IMAGE
|
|
||||||
MGMT_IMAGE
|
|
||||||
fi)
|
|
||||||
container_name: netbird-management
|
|
||||||
restart: unless-stopped
|
|
||||||
networks: [netbird]
|
|
||||||
volumes:
|
|
||||||
- netbird_management:/var/lib/netbird
|
|
||||||
- ./management.json:/etc/netbird/management.json
|
|
||||||
command: [
|
|
||||||
"--port", "80",
|
|
||||||
"--log-file", "console",
|
|
||||||
"--log-level", "info",
|
|
||||||
"--disable-anonymous-metrics=false",
|
|
||||||
"--single-account-mode-domain=netbird.selfhosted",
|
|
||||||
"--dns-domain=netbird.selfhosted",
|
|
||||||
"--idp-sign-key-refresh-enabled",
|
|
||||||
]
|
|
||||||
labels:
|
|
||||||
- traefik.enable=true
|
|
||||||
# Management API
|
|
||||||
- traefik.http.routers.netbird-api.entrypoints=websecure
|
|
||||||
- traefik.http.routers.netbird-api.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/api\`)
|
|
||||||
- traefik.http.routers.netbird-api.tls=true
|
|
||||||
- traefik.http.routers.netbird-api.tls.certresolver=letsencrypt
|
|
||||||
- traefik.http.routers.netbird-api.service=management
|
|
||||||
- traefik.http.routers.netbird-api.priority=100
|
|
||||||
# Management WebSocket
|
|
||||||
- traefik.http.routers.netbird-mgmt-ws.entrypoints=websecure
|
|
||||||
- traefik.http.routers.netbird-mgmt-ws.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/ws-proxy/management\`)
|
|
||||||
- traefik.http.routers.netbird-mgmt-ws.tls=true
|
|
||||||
- traefik.http.routers.netbird-mgmt-ws.tls.certresolver=letsencrypt
|
|
||||||
- traefik.http.routers.netbird-mgmt-ws.service=management
|
|
||||||
- traefik.http.routers.netbird-mgmt-ws.priority=100
|
|
||||||
# Management gRPC
|
|
||||||
- traefik.http.routers.netbird-mgmt-grpc.entrypoints=websecure
|
|
||||||
- traefik.http.routers.netbird-mgmt-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/management.ManagementService/\`)
|
|
||||||
- traefik.http.routers.netbird-mgmt-grpc.tls=true
|
|
||||||
- traefik.http.routers.netbird-mgmt-grpc.tls.certresolver=letsencrypt
|
|
||||||
- traefik.http.routers.netbird-mgmt-grpc.service=management-grpc
|
|
||||||
- traefik.http.routers.netbird-mgmt-grpc.priority=100
|
|
||||||
# OAuth2 (embedded IdP)
|
|
||||||
- traefik.http.routers.netbird-oauth2.entrypoints=websecure
|
|
||||||
- traefik.http.routers.netbird-oauth2.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/oauth2\`)
|
|
||||||
- traefik.http.routers.netbird-oauth2.tls=true
|
|
||||||
- traefik.http.routers.netbird-oauth2.tls.certresolver=letsencrypt
|
|
||||||
- traefik.http.routers.netbird-oauth2.service=management
|
|
||||||
- traefik.http.routers.netbird-oauth2.priority=100
|
|
||||||
# Services
|
|
||||||
- traefik.http.services.management.loadbalancer.server.port=80
|
|
||||||
- traefik.http.services.management-grpc.loadbalancer.server.port=80
|
|
||||||
- traefik.http.services.management-grpc.loadbalancer.server.scheme=h2c
|
|
||||||
logging:
|
|
||||||
driver: "json-file"
|
|
||||||
options:
|
|
||||||
max-size: "500m"
|
|
||||||
max-file: "2"
|
|
||||||
${proxy_service}
|
|
||||||
volumes:
|
|
||||||
netbird_traefik_data:
|
|
||||||
netbird_management:${proxy_volumes}
|
|
||||||
|
|
||||||
networks:
|
|
||||||
netbird:
|
|
||||||
driver: bridge
|
|
||||||
ipam:
|
|
||||||
config:
|
|
||||||
- subnet: 172.30.0.0/24
|
|
||||||
gateway: 172.30.0.1
|
|
||||||
EOF
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
render_npm_advanced_config() {
|
render_npm_advanced_config() {
|
||||||
local upstream_host=$(get_upstream_host)
|
local upstream_host=$(get_upstream_host)
|
||||||
local relay_addr="${upstream_host}:${RELAY_HOST_PORT}"
|
local relay_addr="${upstream_host}:${RELAY_HOST_PORT}"
|
||||||
@@ -1836,36 +1424,6 @@ print_manual_instructions() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
print_traefik_tcp_instructions() {
|
|
||||||
echo ""
|
|
||||||
echo "$MSG_SEPARATOR"
|
|
||||||
echo " TRAEFIK TCP PROXY SETUP"
|
|
||||||
echo "$MSG_SEPARATOR"
|
|
||||||
echo ""
|
|
||||||
echo "This configuration uses Traefik as a single entry point on port 443."
|
|
||||||
echo "Traefik handles TLS termination with Let's Encrypt and routes to services."
|
|
||||||
echo ""
|
|
||||||
echo "Open ports:"
|
|
||||||
echo " - 443/tcp (HTTPS - all NetBird services)"
|
|
||||||
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
|
|
||||||
echo ""
|
|
||||||
echo "Generated files:"
|
|
||||||
echo " - docker-compose.yml (container definitions with Traefik labels)"
|
|
||||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
|
||||||
echo " - proxy.env (NetBird Proxy configuration)"
|
|
||||||
echo ""
|
|
||||||
echo "NetBird Proxy:"
|
|
||||||
echo " The proxy service is enabled and will be built from source."
|
|
||||||
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
|
|
||||||
echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge."
|
|
||||||
echo " Point your proxy domains (CNAMEs) to this server's IP address."
|
|
||||||
fi
|
|
||||||
echo ""
|
|
||||||
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
|
||||||
echo "Follow the onboarding steps to set up your NetBird instance."
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
print_post_setup_instructions() {
|
print_post_setup_instructions() {
|
||||||
case "$REVERSE_PROXY_TYPE" in
|
case "$REVERSE_PROXY_TYPE" in
|
||||||
0)
|
0)
|
||||||
@@ -1886,9 +1444,6 @@ print_post_setup_instructions() {
|
|||||||
5)
|
5)
|
||||||
print_manual_instructions
|
print_manual_instructions
|
||||||
;;
|
;;
|
||||||
6)
|
|
||||||
print_traefik_tcp_instructions
|
|
||||||
;;
|
|
||||||
*)
|
*)
|
||||||
echo "Unknown reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
echo "Unknown reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
||||||
;;
|
;;
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
FROM golang:1.25-bookworm AS builder
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Install build dependencies
|
|
||||||
RUN apt-get update && apt-get install -y gcc libc6-dev && rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
COPY go.mod go.sum ./
|
|
||||||
RUN go mod download
|
|
||||||
|
|
||||||
COPY . .
|
|
||||||
RUN CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o netbird-mgmt ./management
|
|
||||||
|
|
||||||
FROM ubuntu:24.04
|
|
||||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
|
||||||
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
|
|
||||||
CMD ["--log-file", "console"]
|
|
||||||
COPY --from=builder /app/netbird-mgmt /go/bin/netbird-mgmt
|
|
||||||
@@ -19,8 +19,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
"github.com/netbirdio/netbird/management/internals/server"
|
"github.com/netbirdio/netbird/management/internals/server"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
@@ -215,14 +213,11 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
|||||||
// Set HttpConfig values from EmbeddedIdP
|
// Set HttpConfig values from EmbeddedIdP
|
||||||
cfg.HttpConfig.AuthIssuer = issuer
|
cfg.HttpConfig.AuthIssuer = issuer
|
||||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||||
cfg.HttpConfig.AuthClientID = cfg.HttpConfig.AuthAudience
|
|
||||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||||
callbackURL := strings.TrimSuffix(cfg.HttpConfig.AuthIssuer, "/oauth2")
|
|
||||||
cfg.HttpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,10 +80,4 @@ func init() {
|
|||||||
migrationCmd.AddCommand(upCmd)
|
migrationCmd.AddCommand(upCmd)
|
||||||
|
|
||||||
rootCmd.AddCommand(migrationCmd)
|
rootCmd.AddCommand(migrationCmd)
|
||||||
|
|
||||||
tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
|
||||||
tokenCmd.AddCommand(tokenCreateCmd)
|
|
||||||
tokenCmd.AddCommand(tokenListCmd)
|
|
||||||
tokenCmd.AddCommand(tokenRevokeCmd)
|
|
||||||
rootCmd.AddCommand(tokenCmd)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,209 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"text/tabwriter"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
tokenName string
|
|
||||||
tokenExpireIn string
|
|
||||||
tokenDatadir string
|
|
||||||
|
|
||||||
tokenCmd = &cobra.Command{
|
|
||||||
Use: "token",
|
|
||||||
Short: "Manage proxy access tokens",
|
|
||||||
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenCreateCmd = &cobra.Command{
|
|
||||||
Use: "create",
|
|
||||||
Short: "Create a new proxy access token",
|
|
||||||
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
|
||||||
RunE: tokenCreateRun,
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenListCmd = &cobra.Command{
|
|
||||||
Use: "list",
|
|
||||||
Aliases: []string{"ls"},
|
|
||||||
Short: "List all proxy access tokens",
|
|
||||||
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
|
||||||
RunE: tokenListRun,
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenRevokeCmd = &cobra.Command{
|
|
||||||
Use: "revoke [token-id]",
|
|
||||||
Short: "Revoke a proxy access token",
|
|
||||||
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
|
||||||
Args: cobra.ExactArgs(1),
|
|
||||||
RunE: tokenRevokeRun,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
|
||||||
|
|
||||||
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
|
||||||
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
|
||||||
tokenCreateCmd.MarkFlagRequired("name") //nolint
|
|
||||||
}
|
|
||||||
|
|
||||||
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
|
||||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
|
||||||
if err := util.InitLog("error", "console"); err != nil {
|
|
||||||
return fmt.Errorf("init log: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint
|
|
||||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
|
|
||||||
|
|
||||||
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
datadir := config.Datadir
|
|
||||||
if tokenDatadir != "" {
|
|
||||||
datadir = tokenDatadir
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create store: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := s.Close(ctx); err != nil {
|
|
||||||
log.Debugf("close store: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return fn(ctx, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
|
|
||||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
|
||||||
expiresIn, err := parseDuration(tokenExpireIn)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse expiration: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("generate token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
|
||||||
return fmt.Errorf("save token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("Token created successfully!") //nolint:forbidigo
|
|
||||||
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
|
|
||||||
fmt.Println() //nolint:forbidigo
|
|
||||||
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
|
|
||||||
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func tokenListRun(cmd *cobra.Command, _ []string) error {
|
|
||||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
|
||||||
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list tokens: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tokens) == 0 {
|
|
||||||
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
|
||||||
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
|
||||||
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
|
||||||
|
|
||||||
for _, t := range tokens {
|
|
||||||
expires := "never"
|
|
||||||
if t.ExpiresAt != nil {
|
|
||||||
expires = t.ExpiresAt.Format("2006-01-02")
|
|
||||||
}
|
|
||||||
|
|
||||||
lastUsed := "never"
|
|
||||||
if t.LastUsed != nil {
|
|
||||||
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
|
||||||
}
|
|
||||||
|
|
||||||
revoked := "no"
|
|
||||||
if t.Revoked {
|
|
||||||
revoked = "yes"
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
|
||||||
t.ID,
|
|
||||||
t.Name,
|
|
||||||
t.CreatedAt.Format("2006-01-02"),
|
|
||||||
expires,
|
|
||||||
lastUsed,
|
|
||||||
revoked,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Flush()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
|
|
||||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
|
||||||
tokenID := args[0]
|
|
||||||
|
|
||||||
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
|
||||||
return fmt.Errorf("revoke token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
|
||||||
// An empty string returns zero duration (no expiration).
|
|
||||||
func parseDuration(s string) (time.Duration, error) {
|
|
||||||
if len(s) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if s[len(s)-1] == 'd' {
|
|
||||||
d, err := strconv.Atoi(s[:len(s)-1])
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("invalid day format: %s", s)
|
|
||||||
}
|
|
||||||
if d <= 0 {
|
|
||||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
|
||||||
}
|
|
||||||
return time.Duration(d) * 24 * time.Hour, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := time.ParseDuration(s)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if d <= 0 {
|
|
||||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
|
||||||
}
|
|
||||||
return d, nil
|
|
||||||
}
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseDuration(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
expected time.Duration
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty string returns zero",
|
|
||||||
input: "",
|
|
||||||
expected: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "days suffix",
|
|
||||||
input: "30d",
|
|
||||||
expected: 30 * 24 * time.Hour,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "one day",
|
|
||||||
input: "1d",
|
|
||||||
expected: 24 * time.Hour,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "365 days",
|
|
||||||
input: "365d",
|
|
||||||
expected: 365 * 24 * time.Hour,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hours via Go duration",
|
|
||||||
input: "24h",
|
|
||||||
expected: 24 * time.Hour,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "minutes via Go duration",
|
|
||||||
input: "30m",
|
|
||||||
expected: 30 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "complex Go duration",
|
|
||||||
input: "1h30m",
|
|
||||||
expected: 90 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid day format",
|
|
||||||
input: "abcd",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative days",
|
|
||||||
input: "-1d",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero days",
|
|
||||||
input: "0d",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-numeric days",
|
|
||||||
input: "xyzd",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative Go duration",
|
|
||||||
input: "-24h",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero Go duration",
|
|
||||||
input: "0s",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid Go duration",
|
|
||||||
input: "notaduration",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result, err := parseDuration(tt.input)
|
|
||||||
if tt.wantErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -174,7 +174,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
semaphore := make(chan struct{}, 10)
|
semaphore := make(chan struct{}, 10)
|
||||||
|
|
||||||
account.InjectProxyPolicies(ctx)
|
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
@@ -248,10 +247,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||||
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
|
||||||
Update: update,
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
})
|
|
||||||
}(peer)
|
}(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -327,7 +323,6 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
return fmt.Errorf("failed to get validated peers: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
account.InjectProxyPolicies(ctx)
|
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
@@ -375,10 +370,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||||
|
|
||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
|
||||||
Update: update,
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -443,8 +435,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
account.InjectProxyPolicies(ctx)
|
|
||||||
|
|
||||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
@@ -788,7 +778,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
})
|
})
|
||||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||||
|
|
||||||
@@ -851,7 +840,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
if c.experimentalNetworkMap(peer.AccountID) {
|
if c.experimentalNetworkMap(peer.AccountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||||
} else {
|
} else {
|
||||||
account.InjectProxyPolicies(ctx)
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|||||||
@@ -25,14 +25,11 @@ func TestCreateChannel(t *testing.T) {
|
|||||||
func TestSendUpdate(t *testing.T) {
|
func TestSendUpdate(t *testing.T) {
|
||||||
peer := "test-sendupdate"
|
peer := "test-sendupdate"
|
||||||
peersUpdater := NewPeersUpdateManager(nil)
|
peersUpdater := NewPeersUpdateManager(nil)
|
||||||
update1 := &network_map.UpdateMessage{
|
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
||||||
Update: &proto.SyncResponse{
|
NetworkMap: &proto.NetworkMap{
|
||||||
NetworkMap: &proto.NetworkMap{
|
Serial: 0,
|
||||||
Serial: 0,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
}}
|
||||||
}
|
|
||||||
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
||||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||||
t.Error("Error creating the channel")
|
t.Error("Error creating the channel")
|
||||||
@@ -48,14 +45,11 @@ func TestSendUpdate(t *testing.T) {
|
|||||||
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
||||||
}
|
}
|
||||||
|
|
||||||
update2 := &network_map.UpdateMessage{
|
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
||||||
Update: &proto.SyncResponse{
|
NetworkMap: &proto.NetworkMap{
|
||||||
NetworkMap: &proto.NetworkMap{
|
Serial: 10,
|
||||||
Serial: 10,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
}}
|
||||||
}
|
|
||||||
|
|
||||||
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
||||||
timeout := time.After(5 * time.Second)
|
timeout := time.After(5 * time.Second)
|
||||||
|
|||||||
@@ -4,19 +4,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessageType indicates the type of update message for debouncing strategy
|
|
||||||
type MessageType int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall)
|
|
||||||
// These updates can be safely debounced - only the latest state matters
|
|
||||||
MessageTypeNetworkMap MessageType = iota
|
|
||||||
// MessageTypeControlConfig represents control/config updates (tokens, peer expiration)
|
|
||||||
// These updates should not be dropped as they contain time-sensitive information
|
|
||||||
MessageTypeControlConfig
|
|
||||||
)
|
|
||||||
|
|
||||||
type UpdateMessage struct {
|
type UpdateMessage struct {
|
||||||
Update *proto.SyncResponse
|
Update *proto.SyncResponse
|
||||||
MessageType MessageType
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
@@ -33,7 +32,6 @@ type Manager interface {
|
|||||||
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
||||||
SetAccountManager(accountManager account.Manager)
|
SetAccountManager(accountManager account.Manager)
|
||||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||||
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
@@ -184,36 +182,3 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
|||||||
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||||
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
|
||||||
existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
|
||||||
if err == nil && existingPeerID != "" {
|
|
||||||
// Peer already exists
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
name := fmt.Sprintf("proxy-%s", xid.New().String())
|
|
||||||
peer := &peer.Peer{
|
|
||||||
Ephemeral: true,
|
|
||||||
ProxyMeta: peer.ProxyMeta{
|
|
||||||
Cluster: cluster,
|
|
||||||
Embedded: true,
|
|
||||||
},
|
|
||||||
Name: name,
|
|
||||||
Key: peerKey,
|
|
||||||
LoginExpirationEnabled: false,
|
|
||||||
InactivityExpirationEnabled: false,
|
|
||||||
Meta: peer.PeerSystemMeta{
|
|
||||||
Hostname: name,
|
|
||||||
GoOS: "proxy",
|
|
||||||
OS: "proxy",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create proxy peer: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -162,17 +162,3 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateProxyPeer mocks base method.
|
|
||||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
|
||||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,105 +0,0 @@
|
|||||||
package accesslogs
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AccessLogEntry struct {
|
|
||||||
ID string `gorm:"primaryKey"`
|
|
||||||
AccountID string `gorm:"index"`
|
|
||||||
ServiceID string `gorm:"index"`
|
|
||||||
Timestamp time.Time `gorm:"index"`
|
|
||||||
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
|
|
||||||
Method string `gorm:"index"`
|
|
||||||
Host string `gorm:"index"`
|
|
||||||
Path string `gorm:"index"`
|
|
||||||
Duration time.Duration `gorm:"index"`
|
|
||||||
StatusCode int `gorm:"index"`
|
|
||||||
Reason string
|
|
||||||
UserId string `gorm:"index"`
|
|
||||||
AuthMethodUsed string `gorm:"index"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
|
||||||
func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
|
|
||||||
a.ID = serviceLog.GetLogId()
|
|
||||||
a.ServiceID = serviceLog.GetServiceId()
|
|
||||||
a.Timestamp = serviceLog.GetTimestamp().AsTime()
|
|
||||||
a.Method = serviceLog.GetMethod()
|
|
||||||
a.Host = serviceLog.GetHost()
|
|
||||||
a.Path = serviceLog.GetPath()
|
|
||||||
a.Duration = time.Duration(serviceLog.GetDurationMs()) * time.Millisecond
|
|
||||||
a.StatusCode = int(serviceLog.GetResponseCode())
|
|
||||||
a.UserId = serviceLog.GetUserId()
|
|
||||||
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
|
|
||||||
a.AccountID = serviceLog.GetAccountId()
|
|
||||||
|
|
||||||
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
|
||||||
if ip, err := netip.ParseAddr(sourceIP); err == nil {
|
|
||||||
a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !serviceLog.GetAuthSuccess() {
|
|
||||||
a.Reason = "Authentication failed"
|
|
||||||
} else if serviceLog.GetResponseCode() >= 400 {
|
|
||||||
a.Reason = "Request failed"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToAPIResponse converts an AccessLogEntry to the API ProxyAccessLog type
|
|
||||||
func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
|
||||||
var sourceIP *string
|
|
||||||
if a.GeoLocation.ConnectionIP != nil {
|
|
||||||
ip := a.GeoLocation.ConnectionIP.String()
|
|
||||||
sourceIP = &ip
|
|
||||||
}
|
|
||||||
|
|
||||||
var reason *string
|
|
||||||
if a.Reason != "" {
|
|
||||||
reason = &a.Reason
|
|
||||||
}
|
|
||||||
|
|
||||||
var userID *string
|
|
||||||
if a.UserId != "" {
|
|
||||||
userID = &a.UserId
|
|
||||||
}
|
|
||||||
|
|
||||||
var authMethod *string
|
|
||||||
if a.AuthMethodUsed != "" {
|
|
||||||
authMethod = &a.AuthMethodUsed
|
|
||||||
}
|
|
||||||
|
|
||||||
var countryCode *string
|
|
||||||
if a.GeoLocation.CountryCode != "" {
|
|
||||||
countryCode = &a.GeoLocation.CountryCode
|
|
||||||
}
|
|
||||||
|
|
||||||
var cityName *string
|
|
||||||
if a.GeoLocation.CityName != "" {
|
|
||||||
cityName = &a.GeoLocation.CityName
|
|
||||||
}
|
|
||||||
|
|
||||||
return &api.ProxyAccessLog{
|
|
||||||
Id: a.ID,
|
|
||||||
ServiceId: a.ServiceID,
|
|
||||||
Timestamp: a.Timestamp,
|
|
||||||
Method: a.Method,
|
|
||||||
Host: a.Host,
|
|
||||||
Path: a.Path,
|
|
||||||
DurationMs: int(a.Duration.Milliseconds()),
|
|
||||||
StatusCode: a.StatusCode,
|
|
||||||
SourceIp: sourceIP,
|
|
||||||
Reason: reason,
|
|
||||||
UserId: userID,
|
|
||||||
AuthMethodUsed: authMethod,
|
|
||||||
CountryCode: countryCode,
|
|
||||||
CityName: cityName,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
package accesslogs
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// DefaultPageSize is the default number of records per page
|
|
||||||
DefaultPageSize = 50
|
|
||||||
// MaxPageSize is the maximum number of records allowed per page
|
|
||||||
MaxPageSize = 100
|
|
||||||
)
|
|
||||||
|
|
||||||
// AccessLogFilter holds pagination and filtering parameters for access logs
|
|
||||||
type AccessLogFilter struct {
|
|
||||||
// Page is the current page number (1-indexed)
|
|
||||||
Page int
|
|
||||||
// PageSize is the number of records per page
|
|
||||||
PageSize int
|
|
||||||
|
|
||||||
// Filtering parameters
|
|
||||||
Search *string // General search across log ID, host, path, source IP, and user fields
|
|
||||||
SourceIP *string // Filter by source IP address
|
|
||||||
Host *string // Filter by host header
|
|
||||||
Path *string // Filter by request path (supports LIKE pattern)
|
|
||||||
UserID *string // Filter by authenticated user ID
|
|
||||||
UserEmail *string // Filter by user email (requires user lookup)
|
|
||||||
UserName *string // Filter by user name (requires user lookup)
|
|
||||||
Method *string // Filter by HTTP method
|
|
||||||
Status *string // Filter by status: "success" (2xx/3xx) or "failed" (1xx/4xx/5xx)
|
|
||||||
StatusCode *int // Filter by HTTP status code
|
|
||||||
StartDate *time.Time // Filter by timestamp >= start_date
|
|
||||||
EndDate *time.Time // Filter by timestamp <= end_date
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
|
|
||||||
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
|
||||||
queryParams := r.URL.Query()
|
|
||||||
|
|
||||||
f.Page = 1
|
|
||||||
if pageStr := queryParams.Get("page"); pageStr != "" {
|
|
||||||
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
|
|
||||||
f.Page = page
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
f.PageSize = DefaultPageSize
|
|
||||||
if pageSizeStr := queryParams.Get("page_size"); pageSizeStr != "" {
|
|
||||||
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
|
|
||||||
f.PageSize = pageSize
|
|
||||||
if f.PageSize > MaxPageSize {
|
|
||||||
f.PageSize = MaxPageSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if search := queryParams.Get("search"); search != "" {
|
|
||||||
f.Search = &search
|
|
||||||
}
|
|
||||||
|
|
||||||
if sourceIP := queryParams.Get("source_ip"); sourceIP != "" {
|
|
||||||
f.SourceIP = &sourceIP
|
|
||||||
}
|
|
||||||
|
|
||||||
if host := queryParams.Get("host"); host != "" {
|
|
||||||
f.Host = &host
|
|
||||||
}
|
|
||||||
|
|
||||||
if path := queryParams.Get("path"); path != "" {
|
|
||||||
f.Path = &path
|
|
||||||
}
|
|
||||||
|
|
||||||
if userID := queryParams.Get("user_id"); userID != "" {
|
|
||||||
f.UserID = &userID
|
|
||||||
}
|
|
||||||
|
|
||||||
if userEmail := queryParams.Get("user_email"); userEmail != "" {
|
|
||||||
f.UserEmail = &userEmail
|
|
||||||
}
|
|
||||||
|
|
||||||
if userName := queryParams.Get("user_name"); userName != "" {
|
|
||||||
f.UserName = &userName
|
|
||||||
}
|
|
||||||
|
|
||||||
if method := queryParams.Get("method"); method != "" {
|
|
||||||
f.Method = &method
|
|
||||||
}
|
|
||||||
|
|
||||||
if status := queryParams.Get("status"); status != "" {
|
|
||||||
f.Status = &status
|
|
||||||
}
|
|
||||||
|
|
||||||
if statusCodeStr := queryParams.Get("status_code"); statusCodeStr != "" {
|
|
||||||
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil && statusCode > 0 {
|
|
||||||
f.StatusCode = &statusCode
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if startDate := queryParams.Get("start_date"); startDate != "" {
|
|
||||||
parsedStartDate, err := time.Parse(time.RFC3339, startDate)
|
|
||||||
if err == nil {
|
|
||||||
f.StartDate = &parsedStartDate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if endDate := queryParams.Get("end_date"); endDate != "" {
|
|
||||||
parsedEndDate, err := time.Parse(time.RFC3339, endDate)
|
|
||||||
if err == nil {
|
|
||||||
f.EndDate = &parsedEndDate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOffset calculates the database offset for pagination
|
|
||||||
func (f *AccessLogFilter) GetOffset() int {
|
|
||||||
return (f.Page - 1) * f.PageSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLimit returns the page size for database queries
|
|
||||||
func (f *AccessLogFilter) GetLimit() int {
|
|
||||||
return f.PageSize
|
|
||||||
}
|
|
||||||
@@ -1,161 +0,0 @@
|
|||||||
package accesslogs
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
queryParams map[string]string
|
|
||||||
expectedPage int
|
|
||||||
expectedPageSize int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "default values when no params provided",
|
|
||||||
queryParams: map[string]string{},
|
|
||||||
expectedPage: 1,
|
|
||||||
expectedPageSize: DefaultPageSize,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid page and page_size",
|
|
||||||
queryParams: map[string]string{
|
|
||||||
"page": "2",
|
|
||||||
"page_size": "25",
|
|
||||||
},
|
|
||||||
expectedPage: 2,
|
|
||||||
expectedPageSize: 25,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "page_size exceeds max, should cap at MaxPageSize",
|
|
||||||
queryParams: map[string]string{
|
|
||||||
"page": "1",
|
|
||||||
"page_size": "200",
|
|
||||||
},
|
|
||||||
expectedPage: 1,
|
|
||||||
expectedPageSize: MaxPageSize,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid page number, should use default",
|
|
||||||
queryParams: map[string]string{
|
|
||||||
"page": "invalid",
|
|
||||||
"page_size": "10",
|
|
||||||
},
|
|
||||||
expectedPage: 1,
|
|
||||||
expectedPageSize: 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid page_size, should use default",
|
|
||||||
queryParams: map[string]string{
|
|
||||||
"page": "2",
|
|
||||||
"page_size": "invalid",
|
|
||||||
},
|
|
||||||
expectedPage: 2,
|
|
||||||
expectedPageSize: DefaultPageSize,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero page number, should use default",
|
|
||||||
queryParams: map[string]string{
|
|
||||||
"page": "0",
|
|
||||||
"page_size": "10",
|
|
||||||
},
|
|
||||||
expectedPage: 1,
|
|
||||||
expectedPageSize: 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative page number, should use default",
|
|
||||||
queryParams: map[string]string{
|
|
||||||
"page": "-1",
|
|
||||||
"page_size": "10",
|
|
||||||
},
|
|
||||||
expectedPage: 1,
|
|
||||||
expectedPageSize: 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero page_size, should use default",
|
|
||||||
queryParams: map[string]string{
|
|
||||||
"page": "1",
|
|
||||||
"page_size": "0",
|
|
||||||
},
|
|
||||||
expectedPage: 1,
|
|
||||||
expectedPageSize: DefaultPageSize,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
||||||
q := req.URL.Query()
|
|
||||||
for key, value := range tt.queryParams {
|
|
||||||
q.Set(key, value)
|
|
||||||
}
|
|
||||||
req.URL.RawQuery = q.Encode()
|
|
||||||
|
|
||||||
filter := &AccessLogFilter{}
|
|
||||||
filter.ParseFromRequest(req)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedPage, filter.Page, "Page mismatch")
|
|
||||||
assert.Equal(t, tt.expectedPageSize, filter.PageSize, "PageSize mismatch")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccessLogFilter_GetOffset(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
page int
|
|
||||||
pageSize int
|
|
||||||
expectedOffset int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "first page",
|
|
||||||
page: 1,
|
|
||||||
pageSize: 50,
|
|
||||||
expectedOffset: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "second page",
|
|
||||||
page: 2,
|
|
||||||
pageSize: 50,
|
|
||||||
expectedOffset: 50,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "third page with page size 25",
|
|
||||||
page: 3,
|
|
||||||
pageSize: 25,
|
|
||||||
expectedOffset: 50,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "page 10 with page size 10",
|
|
||||||
page: 10,
|
|
||||||
pageSize: 10,
|
|
||||||
expectedOffset: 90,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
filter := &AccessLogFilter{
|
|
||||||
Page: tt.page,
|
|
||||||
PageSize: tt.pageSize,
|
|
||||||
}
|
|
||||||
|
|
||||||
offset := filter.GetOffset()
|
|
||||||
assert.Equal(t, tt.expectedOffset, offset)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccessLogFilter_GetLimit(t *testing.T) {
|
|
||||||
filter := &AccessLogFilter{
|
|
||||||
Page: 2,
|
|
||||||
PageSize: 25,
|
|
||||||
}
|
|
||||||
|
|
||||||
limit := filter.GetLimit()
|
|
||||||
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
|
|
||||||
}
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
package accesslogs
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Manager interface {
|
|
||||||
SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
|
|
||||||
GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
|
|
||||||
}
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
type handler struct {
|
|
||||||
manager accesslogs.Manager
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
|
|
||||||
h := &handler{
|
|
||||||
manager: manager,
|
|
||||||
}
|
|
||||||
|
|
||||||
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var filter accesslogs.AccessLogFilter
|
|
||||||
filter.ParseFromRequest(r)
|
|
||||||
|
|
||||||
logs, totalCount, err := h.manager.GetAllAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, &filter)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
apiLogs := make([]api.ProxyAccessLog, 0, len(logs))
|
|
||||||
for _, log := range logs {
|
|
||||||
apiLogs = append(apiLogs, *log.ToAPIResponse())
|
|
||||||
}
|
|
||||||
|
|
||||||
response := &api.ProxyAccessLogsResponse{
|
|
||||||
Data: apiLogs,
|
|
||||||
Page: filter.Page,
|
|
||||||
PageSize: filter.PageSize,
|
|
||||||
TotalRecords: int(totalCount),
|
|
||||||
TotalPages: getTotalPageCount(int(totalCount), filter.PageSize),
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getTotalPageCount calculates the total number of pages
|
|
||||||
func getTotalPageCount(totalCount, pageSize int) int {
|
|
||||||
if pageSize <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return (totalCount + pageSize - 1) / pageSize
|
|
||||||
}
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
type managerImpl struct {
|
|
||||||
store store.Store
|
|
||||||
permissionsManager permissions.Manager
|
|
||||||
geo geolocation.Geolocation
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
|
|
||||||
return &managerImpl{
|
|
||||||
store: store,
|
|
||||||
permissionsManager: permissionsManager,
|
|
||||||
geo: geo,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveAccessLog saves an access log entry to the database after enriching it
|
|
||||||
func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
|
|
||||||
if m.geo != nil && logEntry.GeoLocation.ConnectionIP != nil {
|
|
||||||
location, err := m.geo.Lookup(logEntry.GeoLocation.ConnectionIP)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get location for access log source IP [%s]: %v", logEntry.GeoLocation.ConnectionIP.String(), err)
|
|
||||||
} else {
|
|
||||||
logEntry.GeoLocation.CountryCode = location.Country.ISOCode
|
|
||||||
logEntry.GeoLocation.CityName = location.City.Names.En
|
|
||||||
logEntry.GeoLocation.GeoNameID = location.City.GeonameID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.store.CreateAccessLog(ctx, logEntry); err != nil {
|
|
||||||
log.WithContext(ctx).WithFields(log.Fields{
|
|
||||||
"service_id": logEntry.ServiceID,
|
|
||||||
"method": logEntry.Method,
|
|
||||||
"host": logEntry.Host,
|
|
||||||
"path": logEntry.Path,
|
|
||||||
"status": logEntry.StatusCode,
|
|
||||||
}).Errorf("failed to save access log: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
|
|
||||||
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, 0, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logs, totalCount, err := m.store.GetAccountAccessLogs(ctx, store.LockingStrengthNone, accountID, *filter)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return logs, totalCount, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveUserFilters converts user email/name filters to user ID filter
|
|
||||||
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
|
|
||||||
if filter.UserEmail == nil && filter.UserName == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
users, err := m.store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var matchingUserIDs []string
|
|
||||||
for _, user := range users {
|
|
||||||
if filter.UserEmail != nil && strings.Contains(strings.ToLower(user.Email), strings.ToLower(*filter.UserEmail)) {
|
|
||||||
matchingUserIDs = append(matchingUserIDs, user.Id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if filter.UserName != nil && strings.Contains(strings.ToLower(user.Name), strings.ToLower(*filter.UserName)) {
|
|
||||||
matchingUserIDs = append(matchingUserIDs, user.Id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(matchingUserIDs) > 0 {
|
|
||||||
filter.UserID = &matchingUserIDs[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
package domain
|
|
||||||
|
|
||||||
type Type string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TypeFree Type = "free"
|
|
||||||
TypeCustom Type = "custom"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Domain struct {
|
|
||||||
ID string `gorm:"unique;primaryKey;autoIncrement"`
|
|
||||||
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
|
|
||||||
AccountID string `gorm:"index"`
|
|
||||||
TargetCluster string // The proxy cluster this domain should be validated against
|
|
||||||
Type Type `gorm:"-"`
|
|
||||||
Validated bool
|
|
||||||
}
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
package domain
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Manager interface {
|
|
||||||
GetDomains(ctx context.Context, accountID, userID string) ([]*Domain, error)
|
|
||||||
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
|
|
||||||
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
|
|
||||||
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
|
|
||||||
}
|
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
type handler struct {
|
|
||||||
manager Manager
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterEndpoints(router *mux.Router, manager Manager) {
|
|
||||||
h := &handler{
|
|
||||||
manager: manager,
|
|
||||||
}
|
|
||||||
|
|
||||||
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
|
|
||||||
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
|
|
||||||
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
|
|
||||||
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
|
|
||||||
switch t {
|
|
||||||
case domain.TypeCustom:
|
|
||||||
return api.ReverseProxyDomainTypeCustom
|
|
||||||
case domain.TypeFree:
|
|
||||||
return api.ReverseProxyDomainTypeFree
|
|
||||||
}
|
|
||||||
// By default return as a "free" domain as that is more restrictive.
|
|
||||||
// TODO: is this correct?
|
|
||||||
return api.ReverseProxyDomainTypeFree
|
|
||||||
}
|
|
||||||
|
|
||||||
func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
|
||||||
resp := api.ReverseProxyDomain{
|
|
||||||
Domain: d.Domain,
|
|
||||||
Id: d.ID,
|
|
||||||
Type: domainTypeToApi(d.Type),
|
|
||||||
Validated: d.Validated,
|
|
||||||
}
|
|
||||||
if d.TargetCluster != "" {
|
|
||||||
resp.TargetCluster = &d.TargetCluster
|
|
||||||
}
|
|
||||||
return resp
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ret := make([]api.ReverseProxyDomain, 0)
|
|
||||||
for _, d := range domains {
|
|
||||||
ret = append(ret, domainToApi(d))
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, ret)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req api.PostApiReverseProxiesDomainsJSONRequestBody
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, req.Domain, req.TargetCluster)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
domainID := mux.Vars(r)["domainId"]
|
|
||||||
if domainID == "" {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID); err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
domainID := mux.Vars(r)["domainId"]
|
|
||||||
if domainID == "" {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go h.manager.ValidateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID)
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusAccepted)
|
|
||||||
}
|
|
||||||
@@ -1,279 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
type store interface {
|
|
||||||
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
|
||||||
|
|
||||||
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
|
|
||||||
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
|
|
||||||
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
|
|
||||||
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
|
|
||||||
UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
|
|
||||||
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type proxyURLProvider interface {
|
|
||||||
GetConnectedProxyURLs() []string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Manager struct {
|
|
||||||
store store
|
|
||||||
validator domain.Validator
|
|
||||||
proxyURLProvider proxyURLProvider
|
|
||||||
permissionsManager permissions.Manager
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
|
|
||||||
return Manager{
|
|
||||||
store: store,
|
|
||||||
proxyURLProvider: proxyURLProvider,
|
|
||||||
validator: domain.Validator{
|
|
||||||
Resolver: net.DefaultResolver,
|
|
||||||
},
|
|
||||||
permissionsManager: permissionsManager,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
domains, err := m.store.ListCustomDomains(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list custom domains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ret []*domain.Domain
|
|
||||||
|
|
||||||
// Add connected proxy clusters as free domains.
|
|
||||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
|
||||||
allowList := m.proxyURLAllowList()
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"accountID": accountID,
|
|
||||||
"proxyAllowList": allowList,
|
|
||||||
}).Debug("getting domains with proxy allow list")
|
|
||||||
|
|
||||||
for _, cluster := range allowList {
|
|
||||||
ret = append(ret, &domain.Domain{
|
|
||||||
Domain: cluster,
|
|
||||||
AccountID: accountID,
|
|
||||||
Type: domain.TypeFree,
|
|
||||||
Validated: true,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add custom domains.
|
|
||||||
for _, d := range domains {
|
|
||||||
ret = append(ret, &domain.Domain{
|
|
||||||
ID: d.ID,
|
|
||||||
Domain: d.Domain,
|
|
||||||
AccountID: accountID,
|
|
||||||
TargetCluster: d.TargetCluster,
|
|
||||||
Type: domain.TypeCustom,
|
|
||||||
Validated: d.Validated,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the target cluster is in the available clusters
|
|
||||||
allowList := m.proxyURLAllowList()
|
|
||||||
clusterValid := false
|
|
||||||
for _, cluster := range allowList {
|
|
||||||
if cluster == targetCluster {
|
|
||||||
clusterValid = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !clusterValid {
|
|
||||||
return nil, fmt.Errorf("target cluster %s is not available", targetCluster)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt an initial validation against the specified cluster only
|
|
||||||
var validated bool
|
|
||||||
if m.validator.IsValid(ctx, domainName, []string{targetCluster}) {
|
|
||||||
validated = true
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, targetCluster, validated)
|
|
||||||
if err != nil {
|
|
||||||
return d, fmt.Errorf("create domain in store: %w", err)
|
|
||||||
}
|
|
||||||
return d, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
|
||||||
if err != nil {
|
|
||||||
return status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
|
|
||||||
// TODO: check for "no records" type error. Because that is a success condition.
|
|
||||||
return fmt.Errorf("delete domain from store: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"accountID": accountID,
|
|
||||||
"domainID": domainID,
|
|
||||||
}).WithError(err).Error("validate domain")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"accountID": accountID,
|
|
||||||
"domainID": domainID,
|
|
||||||
}).WithError(err).Error("validate domain")
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"accountID": accountID,
|
|
||||||
"domainID": domainID,
|
|
||||||
}).Info("starting domain validation")
|
|
||||||
|
|
||||||
d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID)
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"accountID": accountID,
|
|
||||||
"domainID": domainID,
|
|
||||||
}).WithError(err).Error("get custom domain from store")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate only against the domain's target cluster
|
|
||||||
targetCluster := d.TargetCluster
|
|
||||||
if targetCluster == "" {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"accountID": accountID,
|
|
||||||
"domainID": domainID,
|
|
||||||
"domain": d.Domain,
|
|
||||||
}).Warn("domain has no target cluster set, skipping validation")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"accountID": accountID,
|
|
||||||
"domainID": domainID,
|
|
||||||
"domain": d.Domain,
|
|
||||||
"targetCluster": targetCluster,
|
|
||||||
}).Info("validating domain against target cluster")
|
|
||||||
|
|
||||||
if m.validator.IsValid(context.Background(), d.Domain, []string{targetCluster}) {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"accountID": accountID,
|
|
||||||
"domainID": domainID,
|
|
||||||
"domain": d.Domain,
|
|
||||||
}).Info("domain validated successfully")
|
|
||||||
d.Validated = true
|
|
||||||
if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"accountID": accountID,
|
|
||||||
"domainID": domainID,
|
|
||||||
"domain": d.Domain,
|
|
||||||
}).WithError(err).Error("update custom domain in store")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"accountID": accountID,
|
|
||||||
"domainID": domainID,
|
|
||||||
"domain": d.Domain,
|
|
||||||
"targetCluster": targetCluster,
|
|
||||||
}).Warn("domain validation failed - CNAME does not match target cluster")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyURLAllowList retrieves a list of currently connected proxies and
|
|
||||||
// their URLs
|
|
||||||
func (m Manager) proxyURLAllowList() []string {
|
|
||||||
var reverseProxyAddresses []string
|
|
||||||
if m.proxyURLProvider != nil {
|
|
||||||
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
|
|
||||||
}
|
|
||||||
return reverseProxyAddresses
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
|
||||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
|
||||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
|
||||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
|
||||||
allowList := m.proxyURLAllowList()
|
|
||||||
if len(allowList) == 0 {
|
|
||||||
return "", fmt.Errorf("no proxy clusters available")
|
|
||||||
}
|
|
||||||
|
|
||||||
if cluster, ok := ExtractClusterFromFreeDomain(domain, allowList); ok {
|
|
||||||
return cluster, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
customDomains, err := m.store.ListCustomDomains(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("list custom domains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
targetCluster, valid := extractClusterFromCustomDomains(domain, customDomains)
|
|
||||||
if valid {
|
|
||||||
return targetCluster, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
|
|
||||||
for _, customDomain := range customDomains {
|
|
||||||
if strings.HasSuffix(domain, "."+customDomain.Domain) {
|
|
||||||
return customDomain.TargetCluster, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
|
|
||||||
// Free domains have the format: <name>.<nonce>.<cluster> (e.g., myapp.abc123.eu.proxy.netbird.io)
|
|
||||||
// It matches the domain suffix against available clusters and returns the matching cluster.
|
|
||||||
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
|
|
||||||
for _, cluster := range availableClusters {
|
|
||||||
if strings.HasSuffix(domain, "."+cluster) {
|
|
||||||
return cluster, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
package domain
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
type resolver interface {
|
|
||||||
LookupCNAME(context.Context, string) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Validator struct {
|
|
||||||
Resolver resolver
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewValidator initializes a validator with a specific DNS Resolver.
|
|
||||||
// If a Validator is used without specifying a Resolver, then it will
|
|
||||||
// use the net.DefaultResolver.
|
|
||||||
func NewValidator(resolver resolver) *Validator {
|
|
||||||
return &Validator{
|
|
||||||
Resolver: resolver,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValid looks up the CNAME record for the passed domain with a prefix
|
|
||||||
// and compares it against the acceptable domains.
|
|
||||||
// If the returned CNAME matches any accepted domain, it will return true,
|
|
||||||
// otherwise, including in the event of a DNS error, it will return false.
|
|
||||||
// The comparison is very simple, so wildcards will not match if included
|
|
||||||
// in the acceptable domain list.
|
|
||||||
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
|
|
||||||
_, valid := v.ValidateWithCluster(ctx, domain, accept)
|
|
||||||
return valid
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateWithCluster validates a custom domain and returns the matched cluster address.
|
|
||||||
// Returns the cluster address and true if valid, or empty string and false if invalid.
|
|
||||||
func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) {
|
|
||||||
if v.Resolver == nil {
|
|
||||||
v.Resolver = net.DefaultResolver
|
|
||||||
}
|
|
||||||
|
|
||||||
lookupDomain := "validation." + domain
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"domain": domain,
|
|
||||||
"lookupDomain": lookupDomain,
|
|
||||||
"acceptList": accept,
|
|
||||||
}).Debug("looking up CNAME for domain validation")
|
|
||||||
|
|
||||||
cname, err := v.Resolver.LookupCNAME(ctx, lookupDomain)
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"domain": domain,
|
|
||||||
"lookupDomain": lookupDomain,
|
|
||||||
}).WithError(err).Warn("CNAME lookup failed for domain validation")
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
nakedCNAME := strings.TrimSuffix(cname, ".")
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"domain": domain,
|
|
||||||
"cname": cname,
|
|
||||||
"nakedCNAME": nakedCNAME,
|
|
||||||
"acceptList": accept,
|
|
||||||
}).Debug("CNAME lookup result for domain validation")
|
|
||||||
|
|
||||||
for _, acceptDomain := range accept {
|
|
||||||
normalizedAccept := strings.TrimSuffix(acceptDomain, ".")
|
|
||||||
if nakedCNAME == normalizedAccept {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"domain": domain,
|
|
||||||
"cname": nakedCNAME,
|
|
||||||
"cluster": acceptDomain,
|
|
||||||
}).Info("domain CNAME matched cluster")
|
|
||||||
return acceptDomain, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"domain": domain,
|
|
||||||
"cname": nakedCNAME,
|
|
||||||
"acceptList": accept,
|
|
||||||
}).Warn("domain CNAME does not match any accepted cluster")
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
package domain_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type resolver struct {
|
|
||||||
CNAME string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r resolver) LookupCNAME(_ context.Context, _ string) (string, error) {
|
|
||||||
return r.CNAME, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsValid(t *testing.T) {
|
|
||||||
tests := map[string]struct {
|
|
||||||
resolver interface {
|
|
||||||
LookupCNAME(context.Context, string) (string, error)
|
|
||||||
}
|
|
||||||
domain string
|
|
||||||
accept []string
|
|
||||||
expect bool
|
|
||||||
}{
|
|
||||||
"match": {
|
|
||||||
resolver: resolver{"bar.example.com."}, // Including trailing "." in response.
|
|
||||||
domain: "foo.example.com",
|
|
||||||
accept: []string{"bar.example.com"},
|
|
||||||
expect: true,
|
|
||||||
},
|
|
||||||
"no match": {
|
|
||||||
resolver: resolver{"invalid"},
|
|
||||||
domain: "foo.example.com",
|
|
||||||
accept: []string{"bar.example.com"},
|
|
||||||
expect: false,
|
|
||||||
},
|
|
||||||
"accept trailing dot": {
|
|
||||||
resolver: resolver{"bar.example.com."},
|
|
||||||
domain: "foo.example.com",
|
|
||||||
accept: []string{"bar.example.com."}, // Including trailing "." in accept.
|
|
||||||
expect: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, test := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
validator := domain.NewValidator(test.resolver)
|
|
||||||
actual := validator.IsValid(t.Context(), test.domain, test.accept)
|
|
||||||
if test.expect != actual {
|
|
||||||
t.Errorf("Incorrect return value:\nexpect: %v\nactual: %v", test.expect, actual)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package reverseproxy
|
|
||||||
|
|
||||||
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Manager interface {
|
|
||||||
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
|
||||||
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
|
||||||
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
|
||||||
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
|
||||||
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
|
||||||
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
|
||||||
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
|
|
||||||
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
|
||||||
ReloadService(ctx context.Context, accountID, serviceID string) error
|
|
||||||
GetGlobalServices(ctx context.Context) ([]*Service, error)
|
|
||||||
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
|
|
||||||
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
|
|
||||||
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
|
||||||
}
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: ./interface.go
|
|
||||||
|
|
||||||
// Package reverseproxy is a generated GoMock package.
|
|
||||||
package reverseproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
context "context"
|
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockManager is a mock of Manager interface.
|
|
||||||
type MockManager struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockManagerMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockManagerMockRecorder is the mock recorder for MockManager.
|
|
||||||
type MockManagerMockRecorder struct {
|
|
||||||
mock *MockManager
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockManager creates a new mock instance.
|
|
||||||
func NewMockManager(ctrl *gomock.Controller) *MockManager {
|
|
||||||
mock := &MockManager{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockManagerMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
|
||||||
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateService mocks base method.
|
|
||||||
func (m *MockManager) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "CreateService", ctx, accountID, userID, service)
|
|
||||||
ret0, _ := ret[0].(*Service)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateService indicates an expected call of CreateService.
|
|
||||||
func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteService mocks base method.
|
|
||||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "DeleteService", ctx, accountID, userID, serviceID)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteService indicates an expected call of DeleteService.
|
|
||||||
func (mr *MockManagerMockRecorder) DeleteService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockManager)(nil).DeleteService), ctx, accountID, userID, serviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountServices mocks base method.
|
|
||||||
func (m *MockManager) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetAccountServices", ctx, accountID)
|
|
||||||
ret0, _ := ret[0].([]*Service)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountServices indicates an expected call of GetAccountServices.
|
|
||||||
func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllServices mocks base method.
|
|
||||||
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetAllServices", ctx, accountID, userID)
|
|
||||||
ret0, _ := ret[0].([]*Service)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllServices indicates an expected call of GetAllServices.
|
|
||||||
func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetGlobalServices mocks base method.
|
|
||||||
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetGlobalServices", ctx)
|
|
||||||
ret0, _ := ret[0].([]*Service)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetGlobalServices indicates an expected call of GetGlobalServices.
|
|
||||||
func (mr *MockManagerMockRecorder) GetGlobalServices(ctx interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalServices", reflect.TypeOf((*MockManager)(nil).GetGlobalServices), ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetService mocks base method.
|
|
||||||
func (m *MockManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetService", ctx, accountID, userID, serviceID)
|
|
||||||
ret0, _ := ret[0].(*Service)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetService indicates an expected call of GetService.
|
|
||||||
func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetServiceByID mocks base method.
|
|
||||||
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetServiceByID", ctx, accountID, serviceID)
|
|
||||||
ret0, _ := ret[0].(*Service)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetServiceByID indicates an expected call of GetServiceByID.
|
|
||||||
func (mr *MockManagerMockRecorder) GetServiceByID(ctx, accountID, serviceID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByID", reflect.TypeOf((*MockManager)(nil).GetServiceByID), ctx, accountID, serviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetServiceIDByTargetID mocks base method.
|
|
||||||
func (m *MockManager) GetServiceIDByTargetID(ctx context.Context, accountID, resourceID string) (string, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetServiceIDByTargetID", ctx, accountID, resourceID)
|
|
||||||
ret0, _ := ret[0].(string)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetServiceIDByTargetID indicates an expected call of GetServiceIDByTargetID.
|
|
||||||
func (mr *MockManagerMockRecorder) GetServiceIDByTargetID(ctx, accountID, resourceID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceIDByTargetID", reflect.TypeOf((*MockManager)(nil).GetServiceIDByTargetID), ctx, accountID, resourceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReloadAllServicesForAccount mocks base method.
|
|
||||||
func (m *MockManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "ReloadAllServicesForAccount", ctx, accountID)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReloadAllServicesForAccount indicates an expected call of ReloadAllServicesForAccount.
|
|
||||||
func (mr *MockManagerMockRecorder) ReloadAllServicesForAccount(ctx, accountID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadAllServicesForAccount", reflect.TypeOf((*MockManager)(nil).ReloadAllServicesForAccount), ctx, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReloadService mocks base method.
|
|
||||||
func (m *MockManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "ReloadService", ctx, accountID, serviceID)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReloadService indicates an expected call of ReloadService.
|
|
||||||
func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCertificateIssuedAt mocks base method.
|
|
||||||
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "SetCertificateIssuedAt", ctx, accountID, serviceID)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCertificateIssuedAt indicates an expected call of SetCertificateIssuedAt.
|
|
||||||
func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, serviceID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCertificateIssuedAt", reflect.TypeOf((*MockManager)(nil).SetCertificateIssuedAt), ctx, accountID, serviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetStatus mocks base method.
|
|
||||||
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetStatus indicates an expected call of SetStatus.
|
|
||||||
func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateService mocks base method.
|
|
||||||
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "UpdateService", ctx, accountID, userID, service)
|
|
||||||
ret0, _ := ret[0].(*Service)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateService indicates an expected call of UpdateService.
|
|
||||||
func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
|
|
||||||
}
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
|
||||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
|
||||||
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
type handler struct {
|
|
||||||
manager reverseproxy.Manager
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterEndpoints registers all service HTTP endpoints.
|
|
||||||
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
|
||||||
h := &handler{
|
|
||||||
manager: manager,
|
|
||||||
}
|
|
||||||
|
|
||||||
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
|
|
||||||
domainmanager.RegisterEndpoints(domainRouter, domainManager)
|
|
||||||
|
|
||||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
|
|
||||||
|
|
||||||
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
|
|
||||||
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
|
|
||||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
|
|
||||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
|
|
||||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
apiServices := make([]*api.Service, 0, len(allServices))
|
|
||||||
for _, service := range allServices {
|
|
||||||
apiServices = append(apiServices, service.ToAPIResponse())
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, apiServices)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req api.ServiceRequest
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
service := new(reverseproxy.Service)
|
|
||||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
|
||||||
|
|
||||||
if err = service.Validate(); err != nil {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
createdService, err := h.manager.CreateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
serviceID := mux.Vars(r)["serviceId"]
|
|
||||||
if serviceID == "" {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
service, err := h.manager.GetService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
serviceID := mux.Vars(r)["serviceId"]
|
|
||||||
if serviceID == "" {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req api.ServiceRequest
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
service := new(reverseproxy.Service)
|
|
||||||
service.ID = serviceID
|
|
||||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
|
||||||
|
|
||||||
if err = service.Validate(); err != nil {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedService, err := h.manager.UpdateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
serviceID := mux.Vars(r)["serviceId"]
|
|
||||||
if serviceID == "" {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.manager.DeleteService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID); err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
|
||||||
}
|
|
||||||
@@ -1,500 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
const unknownHostPlaceholder = "unknown"
|
|
||||||
|
|
||||||
// ClusterDeriver derives the proxy cluster from a domain.
|
|
||||||
type ClusterDeriver interface {
|
|
||||||
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type managerImpl struct {
|
|
||||||
store store.Store
|
|
||||||
accountManager account.Manager
|
|
||||||
permissionsManager permissions.Manager
|
|
||||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
|
||||||
clusterDeriver ClusterDeriver
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewManager creates a new service manager.
|
|
||||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
|
|
||||||
return &managerImpl{
|
|
||||||
store: store,
|
|
||||||
accountManager: accountManager,
|
|
||||||
permissionsManager: permissionsManager,
|
|
||||||
proxyGRPCServer: proxyGRPCServer,
|
|
||||||
clusterDeriver: clusterDeriver,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, service := range services {
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return services, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
|
||||||
for _, target := range service.Targets {
|
|
||||||
switch target.TargetType {
|
|
||||||
case reverseproxy.TargetTypePeer:
|
|
||||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
|
|
||||||
target.Host = unknownHostPlaceholder
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target.Host = peer.IP.String()
|
|
||||||
case reverseproxy.TargetTypeHost:
|
|
||||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
|
||||||
target.Host = unknownHostPlaceholder
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target.Host = resource.Prefix.Addr().String()
|
|
||||||
case reverseproxy.TargetTypeDomain:
|
|
||||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
|
||||||
target.Host = unknownHostPlaceholder
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target.Host = resource.Domain
|
|
||||||
case reverseproxy.TargetTypeSubnet:
|
|
||||||
// For subnets we do not do any lookups on the resource
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
var proxyCluster string
|
|
||||||
if m.clusterDeriver != nil {
|
|
||||||
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
|
||||||
return nil, status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
service.AccountID = accountID
|
|
||||||
service.ProxyCluster = proxyCluster
|
|
||||||
service.InitNewRecord()
|
|
||||||
err = service.Auth.HashSecrets()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate session JWT signing keys
|
|
||||||
keyPair, err := sessionkey.GenerateKeyPair()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("generate session keys: %w", err)
|
|
||||||
}
|
|
||||||
service.SessionPrivateKey = keyPair.PrivateKey
|
|
||||||
service.SessionPublicKey = keyPair.PublicKey
|
|
||||||
|
|
||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
// Check for duplicate domain
|
|
||||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
|
|
||||||
if err != nil {
|
|
||||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
|
||||||
return fmt.Errorf("failed to check existing service: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if existingService != nil {
|
|
||||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = transaction.CreateService(ctx, service); err != nil {
|
|
||||||
return fmt.Errorf("failed to create service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
|
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
var oldCluster string
|
|
||||||
var domainChanged bool
|
|
||||||
var serviceEnabledChanged bool
|
|
||||||
|
|
||||||
err = service.Auth.HashSecrets()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldCluster = existingService.ProxyCluster
|
|
||||||
|
|
||||||
if existingService.Domain != service.Domain {
|
|
||||||
domainChanged = true
|
|
||||||
conflictService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
|
|
||||||
if err != nil {
|
|
||||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
|
||||||
return fmt.Errorf("check existing service: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if conflictService != nil && conflictService.ID != service.ID {
|
|
||||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.clusterDeriver != nil {
|
|
||||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
|
||||||
}
|
|
||||||
service.ProxyCluster = newCluster
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
service.ProxyCluster = existingService.ProxyCluster
|
|
||||||
}
|
|
||||||
|
|
||||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
|
||||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
|
||||||
service.Auth.PasswordAuth.Password == "" {
|
|
||||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
|
||||||
}
|
|
||||||
|
|
||||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
|
||||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
|
||||||
service.Auth.PinAuth.Pin == "" {
|
|
||||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
|
||||||
}
|
|
||||||
|
|
||||||
service.Meta = existingService.Meta
|
|
||||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
|
||||||
service.SessionPublicKey = existingService.SessionPublicKey
|
|
||||||
serviceEnabledChanged = existingService.Enabled != service.Enabled
|
|
||||||
|
|
||||||
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
|
||||||
return fmt.Errorf("update service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
|
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
|
||||||
switch {
|
|
||||||
case domainChanged && oldCluster != service.ProxyCluster:
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), oldCluster)
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
|
||||||
case !service.Enabled && serviceEnabledChanged:
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
|
|
||||||
case service.Enabled && serviceEnabledChanged:
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
|
||||||
default:
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
|
|
||||||
}
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
|
||||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
|
|
||||||
for _, target := range targets {
|
|
||||||
switch target.TargetType {
|
|
||||||
case reverseproxy.TargetTypePeer:
|
|
||||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
|
||||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
|
||||||
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
|
||||||
}
|
|
||||||
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
|
|
||||||
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
|
||||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
|
||||||
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
|
||||||
if err != nil {
|
|
||||||
return status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
var service *reverseproxy.Service
|
|
||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
var err error
|
|
||||||
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
|
||||||
return fmt.Errorf("failed to delete service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
|
|
||||||
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
|
||||||
// Call this when receiving a gRPC notification that the certificate was issued.
|
|
||||||
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
|
||||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
service.Meta.CertificateIssuedAt = time.Now()
|
|
||||||
|
|
||||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
|
||||||
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
|
||||||
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
|
||||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
service.Meta.Status = string(status)
|
|
||||||
|
|
||||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
|
||||||
return fmt.Errorf("failed to update service status: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
|
||||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
|
||||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get services: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, service := range services {
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
|
||||||
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, service := range services {
|
|
||||||
err = m.replaceHostByLookup(ctx, service.AccountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return services, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
|
||||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
|
||||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, service := range services {
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return services, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
|
||||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if target == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return target.ServiceID, nil
|
|
||||||
}
|
|
||||||
@@ -1,463 +0,0 @@
|
|||||||
package reverseproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/xid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
|
||||||
"github.com/netbirdio/netbird/util/crypt"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Operation string
|
|
||||||
|
|
||||||
const (
|
|
||||||
Create Operation = "create"
|
|
||||||
Update Operation = "update"
|
|
||||||
Delete Operation = "delete"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ProxyStatus string
|
|
||||||
|
|
||||||
const (
|
|
||||||
StatusPending ProxyStatus = "pending"
|
|
||||||
StatusActive ProxyStatus = "active"
|
|
||||||
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
|
|
||||||
StatusCertificatePending ProxyStatus = "certificate_pending"
|
|
||||||
StatusCertificateFailed ProxyStatus = "certificate_failed"
|
|
||||||
StatusError ProxyStatus = "error"
|
|
||||||
|
|
||||||
TargetTypePeer = "peer"
|
|
||||||
TargetTypeHost = "host"
|
|
||||||
TargetTypeDomain = "domain"
|
|
||||||
TargetTypeSubnet = "subnet"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Target struct {
|
|
||||||
ID uint `gorm:"primaryKey" json:"-"`
|
|
||||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
|
||||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
|
||||||
Path *string `json:"path,omitempty"`
|
|
||||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
|
||||||
Port int `gorm:"index:idx_target_port" json:"port"`
|
|
||||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
|
||||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
|
||||||
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
|
|
||||||
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PasswordAuthConfig struct {
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PINAuthConfig struct {
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Pin string `json:"pin"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BearerAuthConfig struct {
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthConfig struct {
|
|
||||||
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
|
|
||||||
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
|
|
||||||
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AuthConfig) HashSecrets() error {
|
|
||||||
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
|
|
||||||
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("hash password: %w", err)
|
|
||||||
}
|
|
||||||
a.PasswordAuth.Password = hashedPassword
|
|
||||||
}
|
|
||||||
|
|
||||||
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
|
|
||||||
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("hash pin: %w", err)
|
|
||||||
}
|
|
||||||
a.PinAuth.Pin = hashedPin
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AuthConfig) ClearSecrets() {
|
|
||||||
if a.PasswordAuth != nil {
|
|
||||||
a.PasswordAuth.Password = ""
|
|
||||||
}
|
|
||||||
if a.PinAuth != nil {
|
|
||||||
a.PinAuth.Pin = ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type OIDCValidationConfig struct {
|
|
||||||
Issuer string
|
|
||||||
Audiences []string
|
|
||||||
KeysLocation string
|
|
||||||
MaxTokenAgeSeconds int64
|
|
||||||
}
|
|
||||||
|
|
||||||
type ServiceMeta struct {
|
|
||||||
CreatedAt time.Time
|
|
||||||
CertificateIssuedAt time.Time
|
|
||||||
Status string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Service struct {
|
|
||||||
ID string `gorm:"primaryKey"`
|
|
||||||
AccountID string `gorm:"index"`
|
|
||||||
Name string
|
|
||||||
Domain string `gorm:"index"`
|
|
||||||
ProxyCluster string `gorm:"index"`
|
|
||||||
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
|
|
||||||
Enabled bool
|
|
||||||
PassHostHeader bool
|
|
||||||
RewriteRedirects bool
|
|
||||||
Auth AuthConfig `gorm:"serializer:json"`
|
|
||||||
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
|
||||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
|
||||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
|
||||||
for _, target := range targets {
|
|
||||||
target.AccountID = accountID
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &Service{
|
|
||||||
AccountID: accountID,
|
|
||||||
Name: name,
|
|
||||||
Domain: domain,
|
|
||||||
ProxyCluster: proxyCluster,
|
|
||||||
Targets: targets,
|
|
||||||
Enabled: enabled,
|
|
||||||
}
|
|
||||||
s.InitNewRecord()
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
|
||||||
// Service record. This overwrites any existing ID and Meta fields and should
|
|
||||||
// only be called during initial creation, not for updates.
|
|
||||||
func (s *Service) InitNewRecord() {
|
|
||||||
s.ID = xid.New().String()
|
|
||||||
s.Meta = ServiceMeta{
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
Status: string(StatusPending),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) ToAPIResponse() *api.Service {
|
|
||||||
s.Auth.ClearSecrets()
|
|
||||||
|
|
||||||
authConfig := api.ServiceAuthConfig{}
|
|
||||||
|
|
||||||
if s.Auth.PasswordAuth != nil {
|
|
||||||
authConfig.PasswordAuth = &api.PasswordAuthConfig{
|
|
||||||
Enabled: s.Auth.PasswordAuth.Enabled,
|
|
||||||
Password: s.Auth.PasswordAuth.Password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Auth.PinAuth != nil {
|
|
||||||
authConfig.PinAuth = &api.PINAuthConfig{
|
|
||||||
Enabled: s.Auth.PinAuth.Enabled,
|
|
||||||
Pin: s.Auth.PinAuth.Pin,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Auth.BearerAuth != nil {
|
|
||||||
authConfig.BearerAuth = &api.BearerAuthConfig{
|
|
||||||
Enabled: s.Auth.BearerAuth.Enabled,
|
|
||||||
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert internal targets to API targets
|
|
||||||
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
|
|
||||||
for _, target := range s.Targets {
|
|
||||||
apiTargets = append(apiTargets, api.ServiceTarget{
|
|
||||||
Path: target.Path,
|
|
||||||
Host: &target.Host,
|
|
||||||
Port: target.Port,
|
|
||||||
Protocol: api.ServiceTargetProtocol(target.Protocol),
|
|
||||||
TargetId: target.TargetId,
|
|
||||||
TargetType: api.ServiceTargetTargetType(target.TargetType),
|
|
||||||
Enabled: target.Enabled,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
meta := api.ServiceMeta{
|
|
||||||
CreatedAt: s.Meta.CreatedAt,
|
|
||||||
Status: api.ServiceMetaStatus(s.Meta.Status),
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.Meta.CertificateIssuedAt.IsZero() {
|
|
||||||
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &api.Service{
|
|
||||||
Id: s.ID,
|
|
||||||
Name: s.Name,
|
|
||||||
Domain: s.Domain,
|
|
||||||
Targets: apiTargets,
|
|
||||||
Enabled: s.Enabled,
|
|
||||||
PassHostHeader: &s.PassHostHeader,
|
|
||||||
RewriteRedirects: &s.RewriteRedirects,
|
|
||||||
Auth: authConfig,
|
|
||||||
Meta: meta,
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.ProxyCluster != "" {
|
|
||||||
resp.ProxyCluster = &s.ProxyCluster
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
|
|
||||||
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
|
|
||||||
for _, target := range s.Targets {
|
|
||||||
if !target.Enabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Make path prefix stripping configurable per-target.
|
|
||||||
// Currently the matching prefix is baked into the target URL path,
|
|
||||||
// so the proxy strips-then-re-adds it (effectively a no-op).
|
|
||||||
targetURL := url.URL{
|
|
||||||
Scheme: target.Protocol,
|
|
||||||
Host: target.Host,
|
|
||||||
Path: "/", // TODO: support service path
|
|
||||||
}
|
|
||||||
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
|
||||||
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
|
|
||||||
}
|
|
||||||
|
|
||||||
path := "/"
|
|
||||||
if target.Path != nil {
|
|
||||||
path = *target.Path
|
|
||||||
}
|
|
||||||
pathMappings = append(pathMappings, &proto.PathMapping{
|
|
||||||
Path: path,
|
|
||||||
Target: targetURL.String(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
auth := &proto.Authentication{
|
|
||||||
SessionKey: s.SessionPublicKey,
|
|
||||||
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
|
|
||||||
auth.Password = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
|
|
||||||
auth.Pin = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
|
||||||
auth.Oidc = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return &proto.ProxyMapping{
|
|
||||||
Type: operationToProtoType(operation),
|
|
||||||
Id: s.ID,
|
|
||||||
Domain: s.Domain,
|
|
||||||
Path: pathMappings,
|
|
||||||
AuthToken: authToken,
|
|
||||||
Auth: auth,
|
|
||||||
AccountId: s.AccountID,
|
|
||||||
PassHostHeader: s.PassHostHeader,
|
|
||||||
RewriteRedirects: s.RewriteRedirects,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
|
||||||
switch op {
|
|
||||||
case Create:
|
|
||||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
|
||||||
case Update:
|
|
||||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
|
|
||||||
case Delete:
|
|
||||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
|
||||||
default:
|
|
||||||
log.Fatalf("unknown operation type: %v", op)
|
|
||||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// isDefaultPort reports whether port is the standard default for the given scheme
|
|
||||||
// (443 for https, 80 for http).
|
|
||||||
func isDefaultPort(scheme string, port int) bool {
|
|
||||||
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
|
||||||
s.Name = req.Name
|
|
||||||
s.Domain = req.Domain
|
|
||||||
s.AccountID = accountID
|
|
||||||
|
|
||||||
targets := make([]*Target, 0, len(req.Targets))
|
|
||||||
for _, apiTarget := range req.Targets {
|
|
||||||
target := &Target{
|
|
||||||
AccountID: accountID,
|
|
||||||
Path: apiTarget.Path,
|
|
||||||
Port: apiTarget.Port,
|
|
||||||
Protocol: string(apiTarget.Protocol),
|
|
||||||
TargetId: apiTarget.TargetId,
|
|
||||||
TargetType: string(apiTarget.TargetType),
|
|
||||||
Enabled: apiTarget.Enabled,
|
|
||||||
}
|
|
||||||
if apiTarget.Host != nil {
|
|
||||||
target.Host = *apiTarget.Host
|
|
||||||
}
|
|
||||||
targets = append(targets, target)
|
|
||||||
}
|
|
||||||
s.Targets = targets
|
|
||||||
|
|
||||||
s.Enabled = req.Enabled
|
|
||||||
|
|
||||||
if req.PassHostHeader != nil {
|
|
||||||
s.PassHostHeader = *req.PassHostHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.RewriteRedirects != nil {
|
|
||||||
s.RewriteRedirects = *req.RewriteRedirects
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Auth.PasswordAuth != nil {
|
|
||||||
s.Auth.PasswordAuth = &PasswordAuthConfig{
|
|
||||||
Enabled: req.Auth.PasswordAuth.Enabled,
|
|
||||||
Password: req.Auth.PasswordAuth.Password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Auth.PinAuth != nil {
|
|
||||||
s.Auth.PinAuth = &PINAuthConfig{
|
|
||||||
Enabled: req.Auth.PinAuth.Enabled,
|
|
||||||
Pin: req.Auth.PinAuth.Pin,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Auth.BearerAuth != nil {
|
|
||||||
bearerAuth := &BearerAuthConfig{
|
|
||||||
Enabled: req.Auth.BearerAuth.Enabled,
|
|
||||||
}
|
|
||||||
if req.Auth.BearerAuth.DistributionGroups != nil {
|
|
||||||
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
|
|
||||||
}
|
|
||||||
s.Auth.BearerAuth = bearerAuth
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Validate() error {
|
|
||||||
if s.Name == "" {
|
|
||||||
return errors.New("service name is required")
|
|
||||||
}
|
|
||||||
if len(s.Name) > 255 {
|
|
||||||
return errors.New("service name exceeds maximum length of 255 characters")
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Domain == "" {
|
|
||||||
return errors.New("service domain is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(s.Targets) == 0 {
|
|
||||||
return errors.New("at least one target is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, target := range s.Targets {
|
|
||||||
switch target.TargetType {
|
|
||||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
|
||||||
// host field will be ignored
|
|
||||||
case TargetTypeSubnet:
|
|
||||||
if target.Host == "" {
|
|
||||||
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
|
||||||
}
|
|
||||||
if target.TargetId == "" {
|
|
||||||
return fmt.Errorf("target %d has empty target_id", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) EventMeta() map[string]any {
|
|
||||||
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Copy() *Service {
|
|
||||||
targets := make([]*Target, len(s.Targets))
|
|
||||||
for i, target := range s.Targets {
|
|
||||||
targetCopy := *target
|
|
||||||
targets[i] = &targetCopy
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Service{
|
|
||||||
ID: s.ID,
|
|
||||||
AccountID: s.AccountID,
|
|
||||||
Name: s.Name,
|
|
||||||
Domain: s.Domain,
|
|
||||||
ProxyCluster: s.ProxyCluster,
|
|
||||||
Targets: targets,
|
|
||||||
Enabled: s.Enabled,
|
|
||||||
PassHostHeader: s.PassHostHeader,
|
|
||||||
RewriteRedirects: s.RewriteRedirects,
|
|
||||||
Auth: s.Auth,
|
|
||||||
Meta: s.Meta,
|
|
||||||
SessionPrivateKey: s.SessionPrivateKey,
|
|
||||||
SessionPublicKey: s.SessionPublicKey,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
|
||||||
if enc == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.SessionPrivateKey != "" {
|
|
||||||
var err error
|
|
||||||
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
|
||||||
if enc == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.SessionPrivateKey != "" {
|
|
||||||
var err error
|
|
||||||
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,405 +0,0 @@
|
|||||||
package reverseproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func validProxy() *Service {
|
|
||||||
return &Service{
|
|
||||||
Name: "test",
|
|
||||||
Domain: "example.com",
|
|
||||||
Targets: []*Target{
|
|
||||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_Valid(t *testing.T) {
|
|
||||||
require.NoError(t, validProxy().Validate())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_EmptyName(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Name = ""
|
|
||||||
assert.ErrorContains(t, rp.Validate(), "name is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_EmptyDomain(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Domain = ""
|
|
||||||
assert.ErrorContains(t, rp.Validate(), "domain is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_NoTargets(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Targets = nil
|
|
||||||
assert.ErrorContains(t, rp.Validate(), "at least one target")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_EmptyTargetId(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Targets[0].TargetId = ""
|
|
||||||
assert.ErrorContains(t, rp.Validate(), "empty target_id")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_InvalidTargetType(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Targets[0].TargetType = "invalid"
|
|
||||||
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_ResourceTarget(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Targets = append(rp.Targets, &Target{
|
|
||||||
TargetId: "resource-1",
|
|
||||||
TargetType: TargetTypeHost,
|
|
||||||
Host: "example.org",
|
|
||||||
Port: 443,
|
|
||||||
Protocol: "https",
|
|
||||||
Enabled: true,
|
|
||||||
})
|
|
||||||
require.NoError(t, rp.Validate())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Targets = append(rp.Targets, &Target{
|
|
||||||
TargetId: "",
|
|
||||||
TargetType: TargetTypePeer,
|
|
||||||
Host: "10.0.0.2",
|
|
||||||
Port: 80,
|
|
||||||
Protocol: "http",
|
|
||||||
Enabled: true,
|
|
||||||
})
|
|
||||||
err := rp.Validate()
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), "target 1")
|
|
||||||
assert.Contains(t, err.Error(), "empty target_id")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsDefaultPort(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
scheme string
|
|
||||||
port int
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"http", 80, true},
|
|
||||||
{"https", 443, true},
|
|
||||||
{"http", 443, false},
|
|
||||||
{"https", 80, false},
|
|
||||||
{"http", 8080, false},
|
|
||||||
{"https", 8443, false},
|
|
||||||
{"http", 0, false},
|
|
||||||
{"https", 0, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
|
|
||||||
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
|
||||||
oidcConfig := OIDCValidationConfig{}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
protocol string
|
|
||||||
host string
|
|
||||||
port int
|
|
||||||
wantTarget string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "http with default port 80 omits port",
|
|
||||||
protocol: "http",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 80,
|
|
||||||
wantTarget: "http://10.0.0.1/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "https with default port 443 omits port",
|
|
||||||
protocol: "https",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 443,
|
|
||||||
wantTarget: "https://10.0.0.1/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "port 0 omits port",
|
|
||||||
protocol: "http",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 0,
|
|
||||||
wantTarget: "http://10.0.0.1/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-default port is included",
|
|
||||||
protocol: "http",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 8080,
|
|
||||||
wantTarget: "http://10.0.0.1:8080/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "https with non-default port is included",
|
|
||||||
protocol: "https",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 8443,
|
|
||||||
wantTarget: "https://10.0.0.1:8443/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "http port 443 is included",
|
|
||||||
protocol: "http",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 443,
|
|
||||||
wantTarget: "http://10.0.0.1:443/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "https port 80 is included",
|
|
||||||
protocol: "https",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 80,
|
|
||||||
wantTarget: "https://10.0.0.1:80/",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
rp := &Service{
|
|
||||||
ID: "test-id",
|
|
||||||
AccountID: "acc-1",
|
|
||||||
Domain: "example.com",
|
|
||||||
Targets: []*Target{
|
|
||||||
{
|
|
||||||
TargetId: "peer-1",
|
|
||||||
TargetType: TargetTypePeer,
|
|
||||||
Host: tt.host,
|
|
||||||
Port: tt.port,
|
|
||||||
Protocol: tt.protocol,
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
|
|
||||||
require.Len(t, pm.Path, 1, "should have one path mapping")
|
|
||||||
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
|
|
||||||
rp := &Service{
|
|
||||||
ID: "test-id",
|
|
||||||
AccountID: "acc-1",
|
|
||||||
Domain: "example.com",
|
|
||||||
Targets: []*Target{
|
|
||||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
|
|
||||||
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
|
|
||||||
require.Len(t, pm.Path, 1)
|
|
||||||
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestToProtoMapping_OperationTypes(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
tests := []struct {
|
|
||||||
op Operation
|
|
||||||
want proto.ProxyMappingUpdateType
|
|
||||||
}{
|
|
||||||
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
|
|
||||||
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
|
|
||||||
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(string(tt.op), func(t *testing.T) {
|
|
||||||
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
|
|
||||||
assert.Equal(t, tt.want, pm.Type)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthConfig_HashSecrets(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
config *AuthConfig
|
|
||||||
wantErr bool
|
|
||||||
validate func(*testing.T, *AuthConfig)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "hash password successfully",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Password: "testPassword123",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
|
||||||
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
|
|
||||||
}
|
|
||||||
// Verify the hash can be verified
|
|
||||||
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
|
|
||||||
t.Errorf("Hash verification failed: %v", err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hash PIN successfully",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PinAuth: &PINAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Pin: "123456",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
|
||||||
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
|
|
||||||
}
|
|
||||||
// Verify the hash can be verified
|
|
||||||
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
|
|
||||||
t.Errorf("Hash verification failed: %v", err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hash both password and PIN",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Password: "password",
|
|
||||||
},
|
|
||||||
PinAuth: &PINAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Pin: "9999",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
|
||||||
t.Errorf("Password not hashed with argon2id")
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
|
||||||
t.Errorf("PIN not hashed with argon2id")
|
|
||||||
}
|
|
||||||
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
|
|
||||||
t.Errorf("Password hash verification failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
|
|
||||||
t.Errorf("PIN hash verification failed: %v", err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "skip disabled password auth",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: false,
|
|
||||||
Password: "password",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if config.PasswordAuth.Password != "password" {
|
|
||||||
t.Errorf("Disabled password auth should not be hashed")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "skip empty password",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Password: "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if config.PasswordAuth.Password != "" {
|
|
||||||
t.Errorf("Empty password should remain empty")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "skip nil password auth",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PasswordAuth: nil,
|
|
||||||
PinAuth: &PINAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Pin: "1234",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if config.PasswordAuth != nil {
|
|
||||||
t.Errorf("PasswordAuth should remain nil")
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
|
||||||
t.Errorf("PIN should still be hashed")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
err := tt.config.HashSecrets()
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.validate != nil {
|
|
||||||
tt.validate(t, tt.config)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
|
|
||||||
config := &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Password: "correctPassword",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := config.HashSecrets(); err != nil {
|
|
||||||
t.Fatalf("HashSecrets() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify with wrong password should fail
|
|
||||||
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
|
|
||||||
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
|
||||||
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthConfig_ClearSecrets(t *testing.T) {
|
|
||||||
config := &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Password: "hashedPassword",
|
|
||||||
},
|
|
||||||
PinAuth: &PINAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Pin: "hashedPin",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
config.ClearSecrets()
|
|
||||||
|
|
||||||
if config.PasswordAuth.Password != "" {
|
|
||||||
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
|
|
||||||
}
|
|
||||||
if config.PinAuth.Pin != "" {
|
|
||||||
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
package sessionkey
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/auth"
|
|
||||||
)
|
|
||||||
|
|
||||||
type KeyPair struct {
|
|
||||||
PrivateKey string
|
|
||||||
PublicKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Claims struct {
|
|
||||||
jwt.RegisteredClaims
|
|
||||||
Method auth.Method `json:"method"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func GenerateKeyPair() (*KeyPair, error) {
|
|
||||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("generate ed25519 key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &KeyPair{
|
|
||||||
PrivateKey: base64.StdEncoding.EncodeToString(priv),
|
|
||||||
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
|
||||||
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("decode private key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(privKeyBytes) != ed25519.PrivateKeySize {
|
|
||||||
return "", fmt.Errorf("invalid private key size: got %d, want %d", len(privKeyBytes), ed25519.PrivateKeySize)
|
|
||||||
}
|
|
||||||
|
|
||||||
privKey := ed25519.PrivateKey(privKeyBytes)
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
claims := Claims{
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Issuer: auth.SessionJWTIssuer,
|
|
||||||
Subject: userID,
|
|
||||||
Audience: jwt.ClaimStrings{domain},
|
|
||||||
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
|
|
||||||
IssuedAt: jwt.NewNumericDate(now),
|
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
|
||||||
},
|
|
||||||
Method: method,
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
|
||||||
signedToken, err := token.SignedString(privKey)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("sign token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return signedToken, nil
|
|
||||||
}
|
|
||||||
@@ -21,8 +21,6 @@ import (
|
|||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
|
||||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
@@ -94,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -122,13 +120,11 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
||||||
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
||||||
}
|
}
|
||||||
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
|
|
||||||
s.proxyAuthClose = proxyAuthClose
|
|
||||||
gRPCOpts := []grpc.ServerOption{
|
gRPCOpts := []grpc.ServerOption{
|
||||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||||
grpc.KeepaliveParams(kasp),
|
grpc.KeepaliveParams(kasp),
|
||||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
|
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
|
||||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
|
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||||
@@ -154,53 +150,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
}
|
}
|
||||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||||
|
|
||||||
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
|
|
||||||
log.Info("ProxyService registered on gRPC server")
|
|
||||||
|
|
||||||
return gRPCAPIHandler
|
return gRPCAPIHandler
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
|
||||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
|
||||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
|
|
||||||
s.AfterInit(func(s *BaseServer) {
|
|
||||||
proxyService.SetProxyManager(s.ReverseProxyManager())
|
|
||||||
})
|
|
||||||
return proxyService
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
|
|
||||||
return Create(s, func() nbgrpc.ProxyOIDCConfig {
|
|
||||||
return nbgrpc.ProxyOIDCConfig{
|
|
||||||
Issuer: s.Config.HttpConfig.AuthIssuer,
|
|
||||||
// todo: double check auth clientID value
|
|
||||||
ClientID: s.Config.HttpConfig.AuthClientID, // Reuse dashboard client
|
|
||||||
Scopes: []string{"openid", "profile", "email"},
|
|
||||||
CallbackURL: s.Config.HttpConfig.AuthCallbackURL,
|
|
||||||
HMACKey: []byte(s.Config.DataStoreEncryptionKey), // Use the datastore encryption key for OIDC state HMACs, this should ensure all management instances are using the same key.
|
|
||||||
Audience: s.Config.HttpConfig.AuthAudience,
|
|
||||||
KeysLocation: s.Config.HttpConfig.AuthKeysLocation,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
|
||||||
return Create(s, func() *nbgrpc.OneTimeTokenStore {
|
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
|
||||||
log.Info("One-time token store initialized for proxy authentication")
|
|
||||||
return tokenStore
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
|
||||||
return Create(s, func() accesslogs.Manager {
|
|
||||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
|
||||||
return accessLogManager
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
||||||
// Load server's certificate and private key
|
// Load server's certificate and private key
|
||||||
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
|
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
|
||||||
|
|||||||
@@ -100,8 +100,6 @@ type HttpServerConfig struct {
|
|||||||
CertFile string
|
CertFile string
|
||||||
// CertKey is the location of the certificate private key
|
// CertKey is the location of the certificate private key
|
||||||
CertKey string
|
CertKey string
|
||||||
// AuthClientID is the client id used for proxy SSO auth
|
|
||||||
AuthClientID string
|
|
||||||
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
||||||
AuthAudience string
|
AuthAudience string
|
||||||
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
|
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
|
||||||
@@ -119,8 +117,6 @@ type HttpServerConfig struct {
|
|||||||
IdpSignKeyRefreshEnabled bool
|
IdpSignKeyRefreshEnabled bool
|
||||||
// Extra audience
|
// Extra audience
|
||||||
ExtraAuthAudience string
|
ExtraAuthAudience string
|
||||||
// AuthCallbackDomain contains the callback domain
|
|
||||||
AuthCallbackURL string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Host represents a Netbird host (e.g. STUN, TURN, Signal)
|
// Host represents a Netbird host (e.g. STUN, TURN, Signal)
|
||||||
|
|||||||
@@ -8,9 +8,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
|
||||||
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
@@ -72,14 +69,7 @@ func (s *BaseServer) UsersManager() users.Manager {
|
|||||||
func (s *BaseServer) SettingsManager() settings.Manager {
|
func (s *BaseServer) SettingsManager() settings.Manager {
|
||||||
return Create(s, func() settings.Manager {
|
return Create(s, func() settings.Manager {
|
||||||
extraSettingsManager := integrations.NewManager(s.EventStore())
|
extraSettingsManager := integrations.NewManager(s.EventStore())
|
||||||
|
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
|
||||||
idpConfig := settings.IdpConfig{}
|
|
||||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
|
||||||
idpConfig.EmbeddedIdpEnabled = true
|
|
||||||
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
|
|
||||||
}
|
|
||||||
|
|
||||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,11 +91,6 @@ func (s *BaseServer) AccountManager() account.Manager {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create account manager: %v", err)
|
log.Fatalf("failed to create account manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.AfterInit(func(s *BaseServer) {
|
|
||||||
accountManager.SetServiceManager(s.ReverseProxyManager())
|
|
||||||
})
|
|
||||||
|
|
||||||
return accountManager
|
return accountManager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -162,7 +147,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||||
return Create(s, func() resources.Manager {
|
return Create(s, func() resources.Manager {
|
||||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
|
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,16 +174,3 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
|||||||
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
|
|
||||||
return Create(s, func() reverseproxy.Manager {
|
|
||||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
|
||||||
return Create(s, func() *manager.Manager {
|
|
||||||
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
|
|
||||||
return &m
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
@@ -20,7 +21,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
|
||||||
"github.com/netbirdio/netbird/management/server/metrics"
|
"github.com/netbirdio/netbird/management/server/metrics"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/util/wsproxy"
|
"github.com/netbirdio/netbird/util/wsproxy"
|
||||||
@@ -58,8 +58,6 @@ type BaseServer struct {
|
|||||||
mgmtMetricsPort int
|
mgmtMetricsPort int
|
||||||
mgmtPort int
|
mgmtPort int
|
||||||
|
|
||||||
proxyAuthClose func()
|
|
||||||
|
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
certManager *autocert.Manager
|
certManager *autocert.Manager
|
||||||
update *version.Update
|
update *version.Update
|
||||||
@@ -217,11 +215,6 @@ func (s *BaseServer) Stop() error {
|
|||||||
_ = s.certManager.Listener().Close()
|
_ = s.certManager.Listener().Close()
|
||||||
}
|
}
|
||||||
s.GRPCServer().Stop()
|
s.GRPCServer().Stop()
|
||||||
s.ReverseProxyGRPCServer().Close()
|
|
||||||
if s.proxyAuthClose != nil {
|
|
||||||
s.proxyAuthClose()
|
|
||||||
s.proxyAuthClose = nil
|
|
||||||
}
|
|
||||||
_ = s.Store().Close(ctx)
|
_ = s.Store().Close(ctx)
|
||||||
_ = s.EventStore().Close(ctx)
|
_ = s.EventStore().Close(ctx)
|
||||||
if s.update != nil {
|
if s.update != nil {
|
||||||
|
|||||||
@@ -1,167 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/subtle"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OneTimeTokenStore manages short-lived, single-use authentication tokens
|
|
||||||
// for proxy-to-management RPC authentication. Tokens are generated when
|
|
||||||
// a service is created and must be used exactly once by the proxy
|
|
||||||
// to authenticate a subsequent RPC call.
|
|
||||||
type OneTimeTokenStore struct {
|
|
||||||
tokens map[string]*tokenMetadata
|
|
||||||
mu sync.RWMutex
|
|
||||||
cleanup *time.Ticker
|
|
||||||
cleanupDone chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// tokenMetadata stores information about a one-time token
|
|
||||||
type tokenMetadata struct {
|
|
||||||
ServiceID string
|
|
||||||
AccountID string
|
|
||||||
ExpiresAt time.Time
|
|
||||||
CreatedAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewOneTimeTokenStore creates a new token store with automatic cleanup
|
|
||||||
// of expired tokens. The cleanupInterval determines how often expired
|
|
||||||
// tokens are removed from memory.
|
|
||||||
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
|
||||||
store := &OneTimeTokenStore{
|
|
||||||
tokens: make(map[string]*tokenMetadata),
|
|
||||||
cleanup: time.NewTicker(cleanupInterval),
|
|
||||||
cleanupDone: make(chan struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start background cleanup goroutine
|
|
||||||
go store.cleanupExpired()
|
|
||||||
|
|
||||||
return store
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateToken creates a new cryptographically secure one-time token
|
|
||||||
// with the specified TTL. The token is associated with a specific
|
|
||||||
// accountID and serviceID for validation purposes.
|
|
||||||
//
|
|
||||||
// Returns the generated token string or an error if random generation fails.
|
|
||||||
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
|
||||||
// Generate 32 bytes (256 bits) of cryptographically secure random data
|
|
||||||
randomBytes := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(randomBytes); err != nil {
|
|
||||||
return "", fmt.Errorf("failed to generate random token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode as URL-safe base64 for easy transmission in gRPC
|
|
||||||
token := base64.URLEncoding.EncodeToString(randomBytes)
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
s.tokens[token] = &tokenMetadata{
|
|
||||||
ServiceID: serviceID,
|
|
||||||
AccountID: accountID,
|
|
||||||
ExpiresAt: time.Now().Add(ttl),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
|
||||||
serviceID, accountID, ttl)
|
|
||||||
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateAndConsume verifies the token against the provided accountID and
|
|
||||||
// serviceID, checks expiration, and then deletes it to enforce single-use.
|
|
||||||
//
|
|
||||||
// This method uses constant-time comparison to prevent timing attacks.
|
|
||||||
//
|
|
||||||
// Returns nil on success, or an error if:
|
|
||||||
// - Token doesn't exist
|
|
||||||
// - Token has expired
|
|
||||||
// - Account ID doesn't match
|
|
||||||
// - Reverse proxy ID doesn't match
|
|
||||||
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
metadata, exists := s.tokens[token]
|
|
||||||
if !exists {
|
|
||||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
|
|
||||||
serviceID, accountID)
|
|
||||||
return fmt.Errorf("invalid token")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check expiration
|
|
||||||
if time.Now().After(metadata.ExpiresAt) {
|
|
||||||
delete(s.tokens, token)
|
|
||||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
|
|
||||||
serviceID, accountID)
|
|
||||||
return fmt.Errorf("token expired")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate account ID using constant-time comparison (prevents timing attacks)
|
|
||||||
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
|
||||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
|
|
||||||
metadata.AccountID, accountID)
|
|
||||||
return fmt.Errorf("account ID mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate service ID using constant-time comparison
|
|
||||||
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
|
||||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
|
|
||||||
metadata.ServiceID, serviceID)
|
|
||||||
return fmt.Errorf("service ID mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete token immediately to enforce single-use
|
|
||||||
delete(s.tokens, token)
|
|
||||||
|
|
||||||
log.Infof("Token validated and consumed for proxy %s in account %s",
|
|
||||||
serviceID, accountID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupExpired removes expired tokens in the background to prevent memory leaks
|
|
||||||
func (s *OneTimeTokenStore) cleanupExpired() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-s.cleanup.C:
|
|
||||||
s.mu.Lock()
|
|
||||||
now := time.Now()
|
|
||||||
removed := 0
|
|
||||||
for token, metadata := range s.tokens {
|
|
||||||
if now.After(metadata.ExpiresAt) {
|
|
||||||
delete(s.tokens, token)
|
|
||||||
removed++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if removed > 0 {
|
|
||||||
log.Debugf("Cleaned up %d expired one-time tokens", removed)
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
case <-s.cleanupDone:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close stops the cleanup goroutine and releases resources
|
|
||||||
func (s *OneTimeTokenStore) Close() {
|
|
||||||
s.cleanup.Stop()
|
|
||||||
close(s.cleanupDone)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
|
|
||||||
func (s *OneTimeTokenStore) GetTokenCount() int {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
return len(s.tokens)
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,234 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/metadata"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// lastUsedUpdateInterval is the minimum interval between last_used updates for the same token.
|
|
||||||
lastUsedUpdateInterval = time.Minute
|
|
||||||
// lastUsedCleanupInterval is how often stale lastUsed entries are removed.
|
|
||||||
lastUsedCleanupInterval = 2 * time.Minute
|
|
||||||
)
|
|
||||||
|
|
||||||
type proxyTokenContextKey struct{}
|
|
||||||
|
|
||||||
// ProxyTokenContextKey is the typed key used to store validated token info in context.
|
|
||||||
var ProxyTokenContextKey = proxyTokenContextKey{}
|
|
||||||
|
|
||||||
// proxyTokenID identifies a proxy access token by its database ID.
|
|
||||||
type proxyTokenID = string
|
|
||||||
|
|
||||||
// proxyTokenStore defines the store interface needed for token validation
|
|
||||||
type proxyTokenStore interface {
|
|
||||||
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength store.LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
|
||||||
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyAuthInterceptor holds state for proxy authentication interceptors.
|
|
||||||
type proxyAuthInterceptor struct {
|
|
||||||
store proxyTokenStore
|
|
||||||
failureLimiter *authFailureLimiter
|
|
||||||
|
|
||||||
// lastUsedMu protects lastUsedTimes
|
|
||||||
lastUsedMu sync.Mutex
|
|
||||||
lastUsedTimes map[proxyTokenID]time.Time
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func newProxyAuthInterceptor(tokenStore proxyTokenStore) *proxyAuthInterceptor {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
i := &proxyAuthInterceptor{
|
|
||||||
store: tokenStore,
|
|
||||||
failureLimiter: newAuthFailureLimiter(),
|
|
||||||
lastUsedTimes: make(map[proxyTokenID]time.Time),
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
|
||||||
go i.lastUsedCleanupLoop(ctx)
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewProxyAuthInterceptors creates gRPC unary and stream interceptors that validate proxy access tokens.
|
|
||||||
// They only intercept ProxyService methods. Both interceptors share state for last-used and failure rate limiting.
|
|
||||||
// The returned close function must be called on shutdown to stop background goroutines.
|
|
||||||
func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor, func()) {
|
|
||||||
interceptor := newProxyAuthInterceptor(tokenStore)
|
|
||||||
|
|
||||||
unary := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
|
||||||
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
|
|
||||||
return handler(ctx, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := interceptor.validateProxyToken(ctx)
|
|
||||||
if err != nil {
|
|
||||||
// Log auth failures explicitly; gRPC doesn't log these by default.
|
|
||||||
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
|
|
||||||
return handler(ctx, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
stream := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
|
||||||
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
|
|
||||||
return handler(srv, ss)
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := interceptor.validateProxyToken(ss.Context())
|
|
||||||
if err != nil {
|
|
||||||
// Log auth failures explicitly; gRPC doesn't log these by default.
|
|
||||||
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
|
|
||||||
wrapped := &wrappedServerStream{
|
|
||||||
ServerStream: ss,
|
|
||||||
ctx: ctx,
|
|
||||||
}
|
|
||||||
|
|
||||||
return handler(srv, wrapped)
|
|
||||||
}
|
|
||||||
|
|
||||||
return unary, stream, interceptor.close
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
|
||||||
clientIP := peerIPFromContext(ctx)
|
|
||||||
|
|
||||||
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
|
|
||||||
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := i.doValidateProxyToken(ctx)
|
|
||||||
if err != nil {
|
|
||||||
if clientIP != "" {
|
|
||||||
i.failureLimiter.recordFailure(clientIP)
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
i.maybeUpdateLastUsed(ctx, token.ID)
|
|
||||||
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
|
||||||
md, ok := metadata.FromIncomingContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
|
|
||||||
}
|
|
||||||
|
|
||||||
authValues := md.Get("authorization")
|
|
||||||
if len(authValues) == 0 {
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "missing authorization header")
|
|
||||||
}
|
|
||||||
|
|
||||||
authValue := authValues[0]
|
|
||||||
if !strings.HasPrefix(authValue, "Bearer ") {
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
|
|
||||||
}
|
|
||||||
|
|
||||||
plainToken := types.PlainProxyToken(strings.TrimPrefix(authValue, "Bearer "))
|
|
||||||
|
|
||||||
if err := plainToken.Validate(); err != nil {
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "invalid token format")
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := i.store.GetProxyAccessTokenByHashedToken(ctx, store.LockingStrengthNone, plainToken.Hash())
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
|
|
||||||
// Currently tokens are management-wide; AccountID field is reserved for future use.
|
|
||||||
|
|
||||||
if !token.IsValid() {
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// maybeUpdateLastUsed updates the last_used timestamp if enough time has passed since the last update.
|
|
||||||
func (i *proxyAuthInterceptor) maybeUpdateLastUsed(ctx context.Context, tokenID string) {
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
i.lastUsedMu.Lock()
|
|
||||||
lastUpdate, exists := i.lastUsedTimes[tokenID]
|
|
||||||
if exists && now.Sub(lastUpdate) < lastUsedUpdateInterval {
|
|
||||||
i.lastUsedMu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
i.lastUsedTimes[tokenID] = now
|
|
||||||
i.lastUsedMu.Unlock()
|
|
||||||
|
|
||||||
if err := i.store.MarkProxyAccessTokenUsed(ctx, tokenID); err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("failed to mark proxy token as used: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *proxyAuthInterceptor) lastUsedCleanupLoop(ctx context.Context) {
|
|
||||||
ticker := time.NewTicker(lastUsedCleanupInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
i.cleanupStaleLastUsed()
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupStaleLastUsed removes entries older than 2x the update interval.
|
|
||||||
func (i *proxyAuthInterceptor) cleanupStaleLastUsed() {
|
|
||||||
i.lastUsedMu.Lock()
|
|
||||||
defer i.lastUsedMu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
staleThreshold := 2 * lastUsedUpdateInterval
|
|
||||||
for id, lastUpdate := range i.lastUsedTimes {
|
|
||||||
if now.Sub(lastUpdate) > staleThreshold {
|
|
||||||
delete(i.lastUsedTimes, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *proxyAuthInterceptor) close() {
|
|
||||||
i.cancel()
|
|
||||||
i.failureLimiter.stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetProxyTokenFromContext retrieves the validated proxy token from the context
|
|
||||||
func GetProxyTokenFromContext(ctx context.Context) *types.ProxyAccessToken {
|
|
||||||
token, ok := ctx.Value(ProxyTokenContextKey).(*types.ProxyAccessToken)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
// wrappedServerStream wraps a grpc.ServerStream to provide a custom context
|
|
||||||
type wrappedServerStream struct {
|
|
||||||
grpc.ServerStream
|
|
||||||
ctx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wrappedServerStream) Context() context.Context {
|
|
||||||
return w.ctx
|
|
||||||
}
|
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
"google.golang.org/grpc/peer"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// proxyAuthFailureBurst is the maximum number of failed attempts before rate limiting kicks in.
|
|
||||||
proxyAuthFailureBurst = 5
|
|
||||||
// proxyAuthLimiterCleanup is how often stale limiters are removed.
|
|
||||||
proxyAuthLimiterCleanup = 5 * time.Minute
|
|
||||||
// proxyAuthLimiterTTL is how long a limiter is kept after the last failure.
|
|
||||||
proxyAuthLimiterTTL = 15 * time.Minute
|
|
||||||
)
|
|
||||||
|
|
||||||
// defaultProxyAuthFailureRate is the token replenishment rate for failed auth attempts.
|
|
||||||
// One token every 12 seconds = 5 per minute.
|
|
||||||
var defaultProxyAuthFailureRate = rate.Every(12 * time.Second)
|
|
||||||
|
|
||||||
// clientIP identifies a client by its IP address for rate limiting purposes.
|
|
||||||
type clientIP = string
|
|
||||||
|
|
||||||
type limiterEntry struct {
|
|
||||||
limiter *rate.Limiter
|
|
||||||
lastAccess time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// authFailureLimiter tracks per-IP rate limits for failed proxy authentication attempts.
|
|
||||||
type authFailureLimiter struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
limiters map[clientIP]*limiterEntry
|
|
||||||
failureRate rate.Limit
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAuthFailureLimiter() *authFailureLimiter {
|
|
||||||
return newAuthFailureLimiterWithRate(defaultProxyAuthFailureRate)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAuthFailureLimiterWithRate(failureRate rate.Limit) *authFailureLimiter {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
l := &authFailureLimiter{
|
|
||||||
limiters: make(map[clientIP]*limiterEntry),
|
|
||||||
failureRate: failureRate,
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
|
||||||
go l.cleanupLoop(ctx)
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
// isLimited returns true if the given IP has exhausted its failure budget.
|
|
||||||
func (l *authFailureLimiter) isLimited(ip clientIP) bool {
|
|
||||||
l.mu.Lock()
|
|
||||||
defer l.mu.Unlock()
|
|
||||||
|
|
||||||
entry, exists := l.limiters[ip]
|
|
||||||
if !exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return entry.limiter.Tokens() < 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// recordFailure consumes a token from the rate limiter for the given IP.
|
|
||||||
func (l *authFailureLimiter) recordFailure(ip clientIP) {
|
|
||||||
l.mu.Lock()
|
|
||||||
defer l.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
entry, exists := l.limiters[ip]
|
|
||||||
if !exists {
|
|
||||||
entry = &limiterEntry{
|
|
||||||
limiter: rate.NewLimiter(l.failureRate, proxyAuthFailureBurst),
|
|
||||||
}
|
|
||||||
l.limiters[ip] = entry
|
|
||||||
}
|
|
||||||
entry.lastAccess = now
|
|
||||||
entry.limiter.Allow()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *authFailureLimiter) cleanupLoop(ctx context.Context) {
|
|
||||||
ticker := time.NewTicker(proxyAuthLimiterCleanup)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
l.cleanup()
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *authFailureLimiter) cleanup() {
|
|
||||||
l.mu.Lock()
|
|
||||||
defer l.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
for ip, entry := range l.limiters {
|
|
||||||
if now.Sub(entry.lastAccess) > proxyAuthLimiterTTL {
|
|
||||||
delete(l.limiters, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *authFailureLimiter) stop() {
|
|
||||||
l.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// peerIPFromContext extracts the client IP from the gRPC context.
|
|
||||||
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
|
|
||||||
func peerIPFromContext(ctx context.Context) clientIP {
|
|
||||||
if addr, ok := realip.FromContext(ctx); ok {
|
|
||||||
return addr.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
if p, ok := peer.FromContext(ctx); ok {
|
|
||||||
host, _, err := net.SplitHostPort(p.Addr.String())
|
|
||||||
if err != nil {
|
|
||||||
return p.Addr.String()
|
|
||||||
}
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiter()
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiter()
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
ip := "192.168.1.1"
|
|
||||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
|
||||||
l.recordFailure(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiter()
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
|
||||||
l.recordFailure("192.168.1.1")
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, l.isLimited("192.168.1.1"))
|
|
||||||
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
ip := "10.0.0.1"
|
|
||||||
|
|
||||||
// Exhaust burst
|
|
||||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
|
||||||
l.recordFailure(ip)
|
|
||||||
}
|
|
||||||
require.True(t, l.isLimited(ip))
|
|
||||||
|
|
||||||
// Wait for token replenishment
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiter()
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
l.recordFailure("10.0.0.1")
|
|
||||||
|
|
||||||
l.mu.Lock()
|
|
||||||
require.Len(t, l.limiters, 1)
|
|
||||||
// Backdate the entry so it looks stale
|
|
||||||
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
|
||||||
l.mu.Unlock()
|
|
||||||
|
|
||||||
l.cleanup()
|
|
||||||
|
|
||||||
l.mu.Lock()
|
|
||||||
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
|
|
||||||
l.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiter()
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
l.recordFailure("10.0.0.1")
|
|
||||||
l.recordFailure("10.0.0.2")
|
|
||||||
|
|
||||||
l.mu.Lock()
|
|
||||||
// Only backdate one entry
|
|
||||||
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
|
||||||
l.mu.Unlock()
|
|
||||||
|
|
||||||
l.cleanup()
|
|
||||||
|
|
||||||
l.mu.Lock()
|
|
||||||
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
|
|
||||||
assert.Contains(t, l.limiters, "10.0.0.2")
|
|
||||||
l.mu.Unlock()
|
|
||||||
}
|
|
||||||
@@ -1,381 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockReverseProxyManager struct {
|
|
||||||
proxiesByAccount map[string][]*reverseproxy.Service
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
|
||||||
if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.proxiesByAccount[accountID], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
|
||||||
return []*reverseproxy.Service{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
|
|
||||||
return &reverseproxy.Service{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
return &reverseproxy.Service{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
return &reverseproxy.Service{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
|
|
||||||
return &reverseproxy.Service{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockUsersManager struct {
|
|
||||||
users map[string]*types.User
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
|
||||||
if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
user, ok := m.users[userID]
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("user not found")
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateUserGroupAccess(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
domain string
|
|
||||||
userID string
|
|
||||||
proxiesByAccount map[string][]*reverseproxy.Service
|
|
||||||
users map[string]*types.User
|
|
||||||
proxyErr error
|
|
||||||
userErr error
|
|
||||||
expectErr bool
|
|
||||||
expectErrMsg string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "user not found",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "unknown-user",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{},
|
|
||||||
expectErr: true,
|
|
||||||
expectErrMsg: "user not found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy not found in user's account",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
expectErrMsg: "service not found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy exists in different account - not accessible",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
expectErrMsg: "service not found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no bearer auth configured - same account allows access",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bearer auth disabled - same account allows access",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{
|
|
||||||
Domain: "app.example.com",
|
|
||||||
AccountID: "account1",
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bearer auth enabled but no groups configured - same account allows access",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{
|
|
||||||
Domain: "app.example.com",
|
|
||||||
AccountID: "account1",
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DistributionGroups: []string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user not in allowed groups",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{
|
|
||||||
Domain: "app.example.com",
|
|
||||||
AccountID: "account1",
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DistributionGroups: []string{"group1", "group2"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
expectErrMsg: "not in allowed groups",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user in one of the allowed groups - allow access",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{
|
|
||||||
Domain: "app.example.com",
|
|
||||||
AccountID: "account1",
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DistributionGroups: []string{"group1", "group2"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user in all allowed groups - allow access",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{
|
|
||||||
Domain: "app.example.com",
|
|
||||||
AccountID: "account1",
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DistributionGroups: []string{"group1", "group2"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy manager error",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: nil,
|
|
||||||
proxyErr: errors.New("database error"),
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
expectErrMsg: "get account services",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple proxies in account - finds correct one",
|
|
||||||
domain: "app2.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {
|
|
||||||
{Domain: "app1.example.com", AccountID: "account1"},
|
|
||||||
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
|
|
||||||
{Domain: "app3.example.com", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
server := &ProxyServiceServer{
|
|
||||||
reverseProxyManager: &mockReverseProxyManager{
|
|
||||||
proxiesByAccount: tt.proxiesByAccount,
|
|
||||||
err: tt.proxyErr,
|
|
||||||
},
|
|
||||||
usersManager: &mockUsersManager{
|
|
||||||
users: tt.users,
|
|
||||||
err: tt.userErr,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
|
|
||||||
|
|
||||||
if tt.expectErr {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), tt.expectErrMsg)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetAccountProxyByDomain(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
accountID string
|
|
||||||
domain string
|
|
||||||
proxiesByAccount map[string][]*reverseproxy.Service
|
|
||||||
err error
|
|
||||||
expectProxy bool
|
|
||||||
expectErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "proxy found",
|
|
||||||
accountID: "account1",
|
|
||||||
domain: "app.example.com",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {
|
|
||||||
{Domain: "other.example.com", AccountID: "account1"},
|
|
||||||
{Domain: "app.example.com", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectProxy: true,
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy not found in account",
|
|
||||||
accountID: "account1",
|
|
||||||
domain: "unknown.example.com",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
|
||||||
},
|
|
||||||
expectProxy: false,
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty proxy list for account",
|
|
||||||
accountID: "account1",
|
|
||||||
domain: "app.example.com",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
|
||||||
expectProxy: false,
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "manager error",
|
|
||||||
accountID: "account1",
|
|
||||||
domain: "app.example.com",
|
|
||||||
proxiesByAccount: nil,
|
|
||||||
err: errors.New("database error"),
|
|
||||||
expectProxy: false,
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
server := &ProxyServiceServer{
|
|
||||||
reverseProxyManager: &mockReverseProxyManager{
|
|
||||||
proxiesByAccount: tt.proxiesByAccount,
|
|
||||||
err: tt.err,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
|
|
||||||
|
|
||||||
if tt.expectErr {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Nil(t, proxy)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, proxy)
|
|
||||||
assert.Equal(t, tt.domain, proxy.Domain)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,232 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
|
||||||
// and returns the channel where messages will be received.
|
|
||||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
|
|
||||||
ch := make(chan *proto.ProxyMapping, 10)
|
|
||||||
conn := &proxyConnection{
|
|
||||||
proxyID: proxyID,
|
|
||||||
address: clusterAddr,
|
|
||||||
sendChan: ch,
|
|
||||||
}
|
|
||||||
s.connectedProxies.Store(proxyID, conn)
|
|
||||||
|
|
||||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
|
||||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
|
||||||
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
|
||||||
select {
|
|
||||||
case msg := <-ch:
|
|
||||||
return msg
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
|
||||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
|
||||||
defer tokenStore.Close()
|
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
tokenStore: tokenStore,
|
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
}
|
|
||||||
|
|
||||||
const cluster = "proxy.example.com"
|
|
||||||
const numProxies = 3
|
|
||||||
|
|
||||||
channels := make([]chan *proto.ProxyMapping, numProxies)
|
|
||||||
for i := range numProxies {
|
|
||||||
id := "proxy-" + string(rune('a'+i))
|
|
||||||
channels[i] = registerFakeProxy(s, id, cluster)
|
|
||||||
}
|
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
|
||||||
Id: "service-1",
|
|
||||||
AccountId: "account-1",
|
|
||||||
Domain: "test.example.com",
|
|
||||||
Path: []*proto.PathMapping{
|
|
||||||
{Path: "/", Target: "http://10.0.0.1:8080/"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
s.SendServiceUpdateToCluster(update, cluster)
|
|
||||||
|
|
||||||
tokens := make([]string, numProxies)
|
|
||||||
for i, ch := range channels {
|
|
||||||
msg := drainChannel(ch)
|
|
||||||
require.NotNil(t, msg, "proxy %d should receive a message", i)
|
|
||||||
assert.Equal(t, update.Domain, msg.Domain)
|
|
||||||
assert.Equal(t, update.Id, msg.Id)
|
|
||||||
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
|
||||||
tokens[i] = msg.AuthToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// All tokens must be unique
|
|
||||||
tokenSet := make(map[string]struct{})
|
|
||||||
for i, tok := range tokens {
|
|
||||||
_, exists := tokenSet[tok]
|
|
||||||
assert.False(t, exists, "proxy %d got duplicate token", i)
|
|
||||||
tokenSet[tok] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Each token must be independently consumable
|
|
||||||
for i, tok := range tokens {
|
|
||||||
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
|
|
||||||
assert.NoError(t, err, "proxy %d token should validate successfully", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
|
||||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
|
||||||
defer tokenStore.Close()
|
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
tokenStore: tokenStore,
|
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
}
|
|
||||||
|
|
||||||
const cluster = "proxy.example.com"
|
|
||||||
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
|
||||||
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
|
||||||
Id: "service-1",
|
|
||||||
AccountId: "account-1",
|
|
||||||
Domain: "test.example.com",
|
|
||||||
}
|
|
||||||
|
|
||||||
s.SendServiceUpdateToCluster(update, cluster)
|
|
||||||
|
|
||||||
msg1 := drainChannel(ch1)
|
|
||||||
msg2 := drainChannel(ch2)
|
|
||||||
require.NotNil(t, msg1)
|
|
||||||
require.NotNil(t, msg2)
|
|
||||||
|
|
||||||
// Delete operations should not generate tokens
|
|
||||||
assert.Empty(t, msg1.AuthToken)
|
|
||||||
assert.Empty(t, msg2.AuthToken)
|
|
||||||
|
|
||||||
// No tokens should have been created
|
|
||||||
assert.Equal(t, 0, tokenStore.GetTokenCount())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
|
||||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
|
||||||
defer tokenStore.Close()
|
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
tokenStore: tokenStore,
|
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
|
||||||
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
|
||||||
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
|
||||||
Id: "service-1",
|
|
||||||
AccountId: "account-1",
|
|
||||||
Domain: "test.example.com",
|
|
||||||
}
|
|
||||||
|
|
||||||
s.SendServiceUpdate(update)
|
|
||||||
|
|
||||||
msg1 := drainChannel(ch1)
|
|
||||||
msg2 := drainChannel(ch2)
|
|
||||||
require.NotNil(t, msg1)
|
|
||||||
require.NotNil(t, msg2)
|
|
||||||
|
|
||||||
assert.NotEmpty(t, msg1.AuthToken)
|
|
||||||
assert.NotEmpty(t, msg2.AuthToken)
|
|
||||||
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
|
|
||||||
|
|
||||||
// Both tokens should validate
|
|
||||||
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
|
|
||||||
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateState creates a state using the same format as GetOIDCURL.
|
|
||||||
func generateState(s *ProxyServiceServer, redirectURL string) string {
|
|
||||||
nonce := make([]byte, 16)
|
|
||||||
_, _ = rand.Read(nonce)
|
|
||||||
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
|
|
||||||
|
|
||||||
payload := redirectURL + "|" + nonceB64
|
|
||||||
hmacSum := s.generateHMAC(payload)
|
|
||||||
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOAuthState_NeverTheSame(t *testing.T) {
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
oidcConfig: ProxyOIDCConfig{
|
|
||||||
HMACKey: []byte("test-hmac-key"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
redirectURL := "https://app.example.com/callback"
|
|
||||||
|
|
||||||
// Generate 100 states for the same redirect URL
|
|
||||||
states := make(map[string]bool)
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
state := generateState(s, redirectURL)
|
|
||||||
|
|
||||||
// State must have 3 parts: base64(url)|nonce|hmac
|
|
||||||
parts := strings.Split(state, "|")
|
|
||||||
require.Equal(t, 3, len(parts), "state must have 3 parts")
|
|
||||||
|
|
||||||
// State must be unique
|
|
||||||
require.False(t, states[state], "state %d is a duplicate", i)
|
|
||||||
states[state] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
oidcConfig: ProxyOIDCConfig{
|
|
||||||
HMACKey: []byte("test-hmac-key"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Old format had only 2 parts: base64(url)|hmac
|
|
||||||
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
|
||||||
|
|
||||||
_, _, err := s.ValidateState("base64url|hmac")
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), "invalid state format")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
oidcConfig: ProxyOIDCConfig{
|
|
||||||
HMACKey: []byte("test-hmac-key"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store with tampered HMAC
|
|
||||||
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
|
||||||
|
|
||||||
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), "invalid state signature")
|
|
||||||
}
|
|
||||||
@@ -77,9 +77,8 @@ type Server struct {
|
|||||||
|
|
||||||
oAuthConfigProvider idp.OAuthConfigProvider
|
oAuthConfigProvider idp.OAuthConfigProvider
|
||||||
|
|
||||||
syncSem atomic.Int32
|
syncSem atomic.Int32
|
||||||
syncLimEnabled bool
|
syncLim int32
|
||||||
syncLim int32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Management server
|
// NewServer creates a new Management server
|
||||||
@@ -109,7 +108,6 @@ func NewServer(
|
|||||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||||
|
|
||||||
syncLim := int32(defaultSyncLim)
|
syncLim := int32(defaultSyncLim)
|
||||||
syncLimEnabled := true
|
|
||||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -117,9 +115,6 @@ func NewServer(
|
|||||||
} else {
|
} else {
|
||||||
//nolint:gosec
|
//nolint:gosec
|
||||||
syncLim = int32(syncLimParsed)
|
syncLim = int32(syncLimParsed)
|
||||||
if syncLim < 0 {
|
|
||||||
syncLimEnabled = false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,8 +134,7 @@ func NewServer(
|
|||||||
|
|
||||||
loginFilter: newLoginFilter(),
|
loginFilter: newLoginFilter(),
|
||||||
|
|
||||||
syncLim: syncLim,
|
syncLim: syncLim,
|
||||||
syncLimEnabled: syncLimEnabled,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,7 +212,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error {
|
|||||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim {
|
if s.syncSem.Load() >= s.syncLim {
|
||||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||||
}
|
}
|
||||||
s.syncSem.Add(1)
|
s.syncSem.Add(1)
|
||||||
@@ -300,7 +294,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
metahash := metaHash(peerMeta, realIP.String())
|
metahash := metaHash(peerMeta, realIP.String())
|
||||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||||
|
|
||||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
|
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
@@ -311,7 +305,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,7 +313,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,7 +330,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
|
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
|
|
||||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
|
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||||
@@ -404,20 +398,11 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||||
// It implements a backpressure mechanism that sends the first update immediately,
|
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
// then debounces subsequent rapid updates, ensuring only the latest update is sent
|
|
||||||
// after a quiet period.
|
|
||||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
|
||||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||||
|
|
||||||
// Create a debouncer for this peer connection
|
|
||||||
debouncer := NewUpdateDebouncer(1000 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
// condition when there are some updates
|
// condition when there are some updates
|
||||||
// todo set the updates channel size to 1
|
|
||||||
case update, open := <-updates:
|
case update, open := <-updates:
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||||
@@ -425,38 +410,20 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
|
|
||||||
if !open {
|
if !open {
|
||||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||||
if debouncer.ProcessUpdate(update) {
|
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||||
// Send immediately (first update or after quiet period)
|
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
return err
|
||||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Timer expired - quiet period reached, send pending updates if any
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
|
|
||||||
for _, pendingUpdate := range pendingUpdates {
|
|
||||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// condition when client <-> server connection has been terminated
|
// condition when client <-> server connection has been terminated
|
||||||
case <-srv.Context().Done():
|
case <-srv.Context().Done():
|
||||||
// happens when connection drops, e.g. client disconnects
|
// happens when connection drops, e.g. client disconnects
|
||||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||||
return srv.Context().Err()
|
return srv.Context().Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -464,16 +431,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
|
|
||||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||||
// then sends the encrypted message to the connected peer via the sync server.
|
// then sends the encrypted message to the connected peer via the sync server.
|
||||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
key, err := s.secretsManager.GetWGKey()
|
key, err := s.secretsManager.GetWGKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||||
return status.Errorf(codes.Internal, "failed processing update message")
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||||
return status.Errorf(codes.Internal, "failed processing update message")
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
}
|
}
|
||||||
err = srv.Send(&proto.EncryptedMessage{
|
err = srv.Send(&proto.EncryptedMessage{
|
||||||
@@ -481,7 +448,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
|||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||||
return status.Errorf(codes.Internal, "failed sending update message")
|
return status.Errorf(codes.Internal, "failed sending update message")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||||
@@ -513,15 +480,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
|
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
|
||||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -242,10 +242,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
|
|||||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
||||||
Update: update,
|
|
||||||
MessageType: network_map.MessageTypeControlConfig,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
||||||
@@ -269,10 +266,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
|
|||||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
||||||
Update: update,
|
|
||||||
MessageType: network_map.MessageTypeControlConfig,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||||
|
|||||||
@@ -1,103 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
|
||||||
)
|
|
||||||
|
|
||||||
// UpdateDebouncer implements a backpressure mechanism that:
|
|
||||||
// - Sends the first update immediately
|
|
||||||
// - Coalesces rapid subsequent network map updates (only latest matters)
|
|
||||||
// - Queues control/config updates (all must be delivered)
|
|
||||||
// - Preserves the order of messages (important for control configs between network maps)
|
|
||||||
// - Ensures pending updates are sent after a quiet period
|
|
||||||
type UpdateDebouncer struct {
|
|
||||||
debounceInterval time.Duration
|
|
||||||
timer *time.Timer
|
|
||||||
pendingUpdates []*network_map.UpdateMessage // Queue that preserves order
|
|
||||||
timerC <-chan time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewUpdateDebouncer creates a new debouncer with the specified interval
|
|
||||||
func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer {
|
|
||||||
return &UpdateDebouncer{
|
|
||||||
debounceInterval: interval,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessUpdate handles an incoming update and returns whether it should be sent immediately
|
|
||||||
func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool {
|
|
||||||
if d.timer == nil {
|
|
||||||
// No active debounce timer, signal to send immediately
|
|
||||||
// and start the debounce period
|
|
||||||
d.startTimer()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Already in debounce period, accumulate this update preserving order
|
|
||||||
// Check if we should coalesce with the last pending update
|
|
||||||
if len(d.pendingUpdates) > 0 &&
|
|
||||||
update.MessageType == network_map.MessageTypeNetworkMap &&
|
|
||||||
d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap {
|
|
||||||
// Replace the last network map with this one (coalesce consecutive network maps)
|
|
||||||
d.pendingUpdates[len(d.pendingUpdates)-1] = update
|
|
||||||
} else {
|
|
||||||
// Append to the queue (preserves order for control configs and non-consecutive network maps)
|
|
||||||
d.pendingUpdates = append(d.pendingUpdates, update)
|
|
||||||
}
|
|
||||||
d.resetTimer()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// TimerChannel returns the timer channel for select statements
|
|
||||||
func (d *UpdateDebouncer) TimerChannel() <-chan time.Time {
|
|
||||||
if d.timer == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return d.timerC
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPendingUpdates returns and clears all pending updates after timer expiration.
|
|
||||||
// Updates are returned in the order they were received, with consecutive network maps
|
|
||||||
// already coalesced to only the latest one.
|
|
||||||
// If there were pending updates, it restarts the timer to continue debouncing.
|
|
||||||
// If there were no pending updates, it clears the timer (true quiet period).
|
|
||||||
func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage {
|
|
||||||
updates := d.pendingUpdates
|
|
||||||
d.pendingUpdates = nil
|
|
||||||
|
|
||||||
if len(updates) > 0 {
|
|
||||||
// There were pending updates, so updates are still coming rapidly
|
|
||||||
// Restart the timer to continue debouncing mode
|
|
||||||
if d.timer != nil {
|
|
||||||
d.timer.Reset(d.debounceInterval)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No pending updates means true quiet period - return to immediate mode
|
|
||||||
d.timer = nil
|
|
||||||
d.timerC = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return updates
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the debouncer and cleans up resources
|
|
||||||
func (d *UpdateDebouncer) Stop() {
|
|
||||||
if d.timer != nil {
|
|
||||||
d.timer.Stop()
|
|
||||||
d.timer = nil
|
|
||||||
d.timerC = nil
|
|
||||||
}
|
|
||||||
d.pendingUpdates = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *UpdateDebouncer) startTimer() {
|
|
||||||
d.timer = time.NewTimer(d.debounceInterval)
|
|
||||||
d.timerC = d.timer.C
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *UpdateDebouncer) resetTimer() {
|
|
||||||
d.timer.Stop()
|
|
||||||
d.timer.Reset(d.debounceInterval)
|
|
||||||
}
|
|
||||||
@@ -1,587 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldSend := debouncer.ProcessUpdate(update)
|
|
||||||
|
|
||||||
if !shouldSend {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
if debouncer.TimerChannel() == nil {
|
|
||||||
t.Error("Timer should be started after first update")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update3 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update should be sent immediately
|
|
||||||
if !debouncer.ProcessUpdate(update1) {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rapid subsequent updates should be coalesced
|
|
||||||
if debouncer.ProcessUpdate(update2) {
|
|
||||||
t.Error("Second rapid update should not be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
if debouncer.ProcessUpdate(update3) {
|
|
||||||
t.Error("Third rapid update should not be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for debounce period
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) != 1 {
|
|
||||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
if pendingUpdates[0] != update3 {
|
|
||||||
t.Error("Should get the last update (update3)")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send first update
|
|
||||||
debouncer.ProcessUpdate(update1)
|
|
||||||
|
|
||||||
// Send second update within debounce period
|
|
||||||
debouncer.ProcessUpdate(update2)
|
|
||||||
|
|
||||||
// Wait for timer
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) != 1 {
|
|
||||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
if pendingUpdates[0] != update2 {
|
|
||||||
t.Error("Should get the last update")
|
|
||||||
}
|
|
||||||
if pendingUpdates[0] == update1 {
|
|
||||||
t.Error("Should not get the first update")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update3 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send first update
|
|
||||||
debouncer.ProcessUpdate(update1)
|
|
||||||
|
|
||||||
// Wait a bit, but not the full debounce period
|
|
||||||
time.Sleep(30 * time.Millisecond)
|
|
||||||
|
|
||||||
// Send second update - should reset timer
|
|
||||||
debouncer.ProcessUpdate(update2)
|
|
||||||
|
|
||||||
// Wait a bit more
|
|
||||||
time.Sleep(30 * time.Millisecond)
|
|
||||||
|
|
||||||
// Send third update - should reset timer again
|
|
||||||
debouncer.ProcessUpdate(update3)
|
|
||||||
|
|
||||||
// Now wait for the timer (should fire after last update's reset)
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) != 1 {
|
|
||||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
if pendingUpdates[0] != update3 {
|
|
||||||
t.Error("Should get the last update (update3)")
|
|
||||||
}
|
|
||||||
// Timer should be restarted since there was a pending update
|
|
||||||
if debouncer.TimerChannel() == nil {
|
|
||||||
t.Error("Timer should be restarted after sending pending update")
|
|
||||||
}
|
|
||||||
case <-time.After(150 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update3 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update sent immediately
|
|
||||||
debouncer.ProcessUpdate(update1)
|
|
||||||
|
|
||||||
// Second update coalesced
|
|
||||||
debouncer.ProcessUpdate(update2)
|
|
||||||
|
|
||||||
// Wait for timer to expire
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
|
|
||||||
if len(pendingUpdates) == 0 {
|
|
||||||
t.Fatal("Should have pending update")
|
|
||||||
}
|
|
||||||
|
|
||||||
// After sending pending update, timer is restarted, so next update is NOT immediate
|
|
||||||
if debouncer.ProcessUpdate(update3) {
|
|
||||||
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the restarted timer and verify update3 is pending
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
finalUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
|
|
||||||
t.Error("Should get update3 as pending")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired for restarted timer")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
|
|
||||||
update := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send update to start timer
|
|
||||||
debouncer.ProcessUpdate(update)
|
|
||||||
|
|
||||||
// Stop should clean up
|
|
||||||
debouncer.Stop()
|
|
||||||
|
|
||||||
// Multiple stops should be safe
|
|
||||||
debouncer.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
// Simulate high-frequency updates
|
|
||||||
var lastUpdate *network_map.UpdateMessage
|
|
||||||
sentImmediately := 0
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
update := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{
|
|
||||||
NetworkMap: &proto.NetworkMap{
|
|
||||||
Serial: uint64(i),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
lastUpdate = update
|
|
||||||
if debouncer.ProcessUpdate(update) {
|
|
||||||
sentImmediately++
|
|
||||||
}
|
|
||||||
time.Sleep(1 * time.Millisecond) // Very rapid updates
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only first update should be sent immediately
|
|
||||||
if sentImmediately != 1 {
|
|
||||||
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for debounce period
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) != 1 {
|
|
||||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
if pendingUpdates[0] != lastUpdate {
|
|
||||||
t.Error("Should get the very last update")
|
|
||||||
}
|
|
||||||
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
|
|
||||||
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send first update
|
|
||||||
if !debouncer.ProcessUpdate(update) {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for timer to expire with no additional updates (true quiet period)
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) != 0 {
|
|
||||||
t.Error("Should have no pending updates")
|
|
||||||
}
|
|
||||||
// After true quiet period, timer should be cleared
|
|
||||||
if debouncer.TimerChannel() != nil {
|
|
||||||
t.Error("Timer should be cleared after quiet period")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
updates := make([]*network_map.UpdateMessage, 5)
|
|
||||||
for i := range updates {
|
|
||||||
updates[i] = &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{
|
|
||||||
NetworkMap: &proto.NetworkMap{
|
|
||||||
Serial: uint64(i),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update sent immediately
|
|
||||||
debouncer.ProcessUpdate(updates[0])
|
|
||||||
|
|
||||||
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
|
|
||||||
debouncer.ProcessUpdate(updates[1])
|
|
||||||
debouncer.ProcessUpdate(updates[2])
|
|
||||||
debouncer.ProcessUpdate(updates[3])
|
|
||||||
debouncer.ProcessUpdate(updates[4])
|
|
||||||
|
|
||||||
// Wait for debounce
|
|
||||||
<-debouncer.TimerChannel()
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
|
|
||||||
if len(pendingUpdates) != 1 {
|
|
||||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
|
|
||||||
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update sent immediately
|
|
||||||
if !debouncer.ProcessUpdate(update1) {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for timer without sending any more updates (true quiet period)
|
|
||||||
<-debouncer.TimerChannel()
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
|
|
||||||
if len(pendingUpdates) != 0 {
|
|
||||||
t.Error("Should have no pending updates during quiet period")
|
|
||||||
}
|
|
||||||
|
|
||||||
// After true quiet period, next update should be sent immediately
|
|
||||||
if !debouncer.ProcessUpdate(update2) {
|
|
||||||
t.Error("Update after true quiet period should be sent immediately")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
// Simulate continuous high-frequency updates
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
update := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{
|
|
||||||
NetworkMap: &proto.NetworkMap{
|
|
||||||
Serial: uint64(i),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
if i == 0 {
|
|
||||||
// First one sent immediately
|
|
||||||
if !debouncer.ProcessUpdate(update) {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// All others should be coalesced (not sent immediately)
|
|
||||||
if debouncer.ProcessUpdate(update) {
|
|
||||||
t.Errorf("Update %d should not be sent immediately", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait a bit but send next update before debounce expires
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now wait for final debounce
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) == 0 {
|
|
||||||
t.Fatal("Should have the last update pending")
|
|
||||||
}
|
|
||||||
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
|
|
||||||
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
netmapUpdate := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
tokenUpdate1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
|
||||||
MessageType: network_map.MessageTypeControlConfig,
|
|
||||||
}
|
|
||||||
tokenUpdate2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
|
||||||
MessageType: network_map.MessageTypeControlConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update sent immediately
|
|
||||||
debouncer.ProcessUpdate(netmapUpdate)
|
|
||||||
|
|
||||||
// Send multiple control config updates - they should all be queued
|
|
||||||
debouncer.ProcessUpdate(tokenUpdate1)
|
|
||||||
debouncer.ProcessUpdate(tokenUpdate2)
|
|
||||||
|
|
||||||
// Wait for debounce period
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
// Should get both control config updates
|
|
||||||
if len(pendingUpdates) != 2 {
|
|
||||||
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
// Control configs should come first
|
|
||||||
if pendingUpdates[0] != tokenUpdate1 {
|
|
||||||
t.Error("First pending update should be tokenUpdate1")
|
|
||||||
}
|
|
||||||
if pendingUpdates[1] != tokenUpdate2 {
|
|
||||||
t.Error("Second pending update should be tokenUpdate2")
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
netmapUpdate1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
netmapUpdate2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
tokenUpdate := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
|
||||||
MessageType: network_map.MessageTypeControlConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update sent immediately
|
|
||||||
debouncer.ProcessUpdate(netmapUpdate1)
|
|
||||||
|
|
||||||
// Send token update and network map update
|
|
||||||
debouncer.ProcessUpdate(tokenUpdate)
|
|
||||||
debouncer.ProcessUpdate(netmapUpdate2)
|
|
||||||
|
|
||||||
// Wait for debounce period
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
// Should get 2 updates in order: token, then network map
|
|
||||||
if len(pendingUpdates) != 2 {
|
|
||||||
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
// Token update should come first (preserves order)
|
|
||||||
if pendingUpdates[0] != tokenUpdate {
|
|
||||||
t.Error("First pending update should be tokenUpdate")
|
|
||||||
}
|
|
||||||
// Network map update should come second
|
|
||||||
if pendingUpdates[1] != netmapUpdate2 {
|
|
||||||
t.Error("Second pending update should be netmapUpdate2")
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
// Simulate: 50 network maps -> 1 control config -> 50 network maps
|
|
||||||
// Expected result: 3 messages (netmap, controlConfig, netmap)
|
|
||||||
|
|
||||||
// Send first network map immediately
|
|
||||||
firstNetmap := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
if !debouncer.ProcessUpdate(firstNetmap) {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send 49 more network maps (will be coalesced to last one)
|
|
||||||
var lastNetmapBatch1 *network_map.UpdateMessage
|
|
||||||
for i := 1; i < 50; i++ {
|
|
||||||
lastNetmapBatch1 = &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
debouncer.ProcessUpdate(lastNetmapBatch1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send 1 control config
|
|
||||||
controlConfig := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
|
||||||
MessageType: network_map.MessageTypeControlConfig,
|
|
||||||
}
|
|
||||||
debouncer.ProcessUpdate(controlConfig)
|
|
||||||
|
|
||||||
// Send 50 more network maps (will be coalesced to last one)
|
|
||||||
var lastNetmapBatch2 *network_map.UpdateMessage
|
|
||||||
for i := 50; i < 100; i++ {
|
|
||||||
lastNetmapBatch2 = &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
debouncer.ProcessUpdate(lastNetmapBatch2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for debounce period
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
// Should get exactly 3 updates: netmap, controlConfig, netmap
|
|
||||||
if len(pendingUpdates) != 3 {
|
|
||||||
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
// First should be the last netmap from batch 1
|
|
||||||
if pendingUpdates[0] != lastNetmapBatch1 {
|
|
||||||
t.Error("First pending update should be last netmap from batch 1")
|
|
||||||
}
|
|
||||||
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
|
|
||||||
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
|
||||||
}
|
|
||||||
// Second should be the control config
|
|
||||||
if pendingUpdates[1] != controlConfig {
|
|
||||||
t.Error("Second pending update should be control config")
|
|
||||||
}
|
|
||||||
// Third should be the last netmap from batch 2
|
|
||||||
if pendingUpdates[2] != lastNetmapBatch2 {
|
|
||||||
t.Error("Third pending update should be last netmap from batch 2")
|
|
||||||
}
|
|
||||||
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
|
|
||||||
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,304 +0,0 @@
|
|||||||
//go:build integration
|
|
||||||
|
|
||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/proxy/auth"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type validateSessionTestSetup struct {
|
|
||||||
proxyService *ProxyServiceServer
|
|
||||||
store store.Store
|
|
||||||
cleanup func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
proxyManager := &testValidateSessionProxyManager{store: testStore}
|
|
||||||
usersManager := &testValidateSessionUsersManager{store: testStore}
|
|
||||||
|
|
||||||
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
|
|
||||||
proxyService.SetProxyManager(proxyManager)
|
|
||||||
|
|
||||||
createTestProxies(t, ctx, testStore)
|
|
||||||
|
|
||||||
return &validateSessionTestSetup{
|
|
||||||
proxyService: proxyService,
|
|
||||||
store: testStore,
|
|
||||||
cleanup: storeCleanup,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
pubKey, privKey := generateSessionKeyPair(t)
|
|
||||||
|
|
||||||
testProxy := &reverseproxy.Service{
|
|
||||||
ID: "testProxyId",
|
|
||||||
AccountID: "testAccountId",
|
|
||||||
Name: "Test Proxy",
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
Enabled: true,
|
|
||||||
SessionPrivateKey: privKey,
|
|
||||||
SessionPublicKey: pubKey,
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
|
||||||
|
|
||||||
restrictedProxy := &reverseproxy.Service{
|
|
||||||
ID: "restrictedProxyId",
|
|
||||||
AccountID: "testAccountId",
|
|
||||||
Name: "Restricted Proxy",
|
|
||||||
Domain: "restricted-proxy.example.com",
|
|
||||||
Enabled: true,
|
|
||||||
SessionPrivateKey: privKey,
|
|
||||||
SessionPublicKey: pubKey,
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DistributionGroups: []string{"allowedGroupId"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateSessionKeyPair(t *testing.T) (string, string) {
|
|
||||||
t.Helper()
|
|
||||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
require.NoError(t, err)
|
|
||||||
return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv)
|
|
||||||
}
|
|
||||||
|
|
||||||
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
|
||||||
t.Helper()
|
|
||||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
|
|
||||||
require.NoError(t, err)
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_UserAllowed(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
SessionToken: token,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, resp.Valid, "User should be allowed access")
|
|
||||||
assert.Equal(t, "allowedUserId", resp.UserId)
|
|
||||||
assert.Empty(t, resp.DeniedReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "restricted-proxy.example.com",
|
|
||||||
SessionToken: token,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid, "User not in group should be denied")
|
|
||||||
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
|
||||||
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
SessionToken: token,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid, "User in different account should be denied")
|
|
||||||
assert.Equal(t, "account_mismatch", resp.DeniedReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_UserNotFound(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
SessionToken: token,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid, "Non-existent user should be denied")
|
|
||||||
assert.Equal(t, "user_not_found", resp.DeniedReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_ProxyNotFound(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "unknown-proxy.example.com",
|
|
||||||
SessionToken: token,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid, "Unknown proxy should be denied")
|
|
||||||
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_InvalidToken(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
SessionToken: "invalid-token",
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid, "Invalid token should be denied")
|
|
||||||
assert.Equal(t, "invalid_token", resp.DeniedReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_MissingDomain(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
SessionToken: "some-token",
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid)
|
|
||||||
assert.Contains(t, resp.DeniedReason, "missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_MissingToken(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid)
|
|
||||||
assert.Contains(t, resp.DeniedReason, "missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
type testValidateSessionProxyManager struct {
|
|
||||||
store store.Store
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type testValidateSessionUsersManager struct {
|
|
||||||
store store.Store
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
|
||||||
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
|
||||||
}
|
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
@@ -27,6 +26,7 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
@@ -49,7 +49,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,9 +82,8 @@ type DefaultAccountManager struct {
|
|||||||
|
|
||||||
requestBuffer *AccountRequestBuffer
|
requestBuffer *AccountRequestBuffer
|
||||||
|
|
||||||
proxyController port_forwarding.Controller
|
proxyController port_forwarding.Controller
|
||||||
settingsManager settings.Manager
|
settingsManager settings.Manager
|
||||||
reverseProxyManager reverseproxy.Manager
|
|
||||||
|
|
||||||
// config contains the management server configuration
|
// config contains the management server configuration
|
||||||
config *nbconfig.Config
|
config *nbconfig.Config
|
||||||
@@ -115,10 +113,6 @@ type DefaultAccountManager struct {
|
|||||||
|
|
||||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
|
||||||
am.reverseProxyManager = serviceManager
|
|
||||||
}
|
|
||||||
|
|
||||||
func isUniqueConstraintError(err error) bool {
|
func isUniqueConstraintError(err error) bool {
|
||||||
switch {
|
switch {
|
||||||
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
||||||
@@ -327,9 +321,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
|
||||||
}
|
|
||||||
updateAccountPeers = true
|
updateAccountPeers = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -804,19 +795,6 @@ func IsEmbeddedIdp(i idp.Manager) bool {
|
|||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsLocalAuthDisabled checks if local (email/password) authentication is disabled.
|
|
||||||
// Returns true only when using embedded IDP with local auth disabled in config.
|
|
||||||
func IsLocalAuthDisabled(ctx context.Context, i idp.Manager) bool {
|
|
||||||
if isNil(i) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
embeddedIdp, ok := i.(*idp.EmbeddedIdPManager)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return embeddedIdp.IsLocalAuthDisabled()
|
|
||||||
}
|
|
||||||
|
|
||||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||||
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
||||||
@@ -1679,13 +1657,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
|||||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
|
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
@@ -1693,20 +1671,8 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
|||||||
return peer, netMap, postureChecks, dnsfwdPort, nil
|
return peer, netMap, postureChecks, dnsfwdPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer.Status.LastSeen.After(streamStartTime) {
|
|
||||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
|
|
||||||
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@@ -59,7 +58,7 @@ type Manager interface {
|
|||||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
||||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||||
@@ -115,8 +114,8 @@ type Manager interface {
|
|||||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
|
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||||
@@ -140,5 +139,4 @@ type Manager interface {
|
|||||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||||
SetServiceManager(serviceManager reverseproxy.Manager)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,8 +27,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
@@ -1802,14 +1800,6 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
Address: "172.12.6.1/24",
|
Address: "172.12.6.1/24",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Services: []*reverseproxy.Service{
|
|
||||||
{
|
|
||||||
ID: "service1",
|
|
||||||
Name: "test-service",
|
|
||||||
AccountID: "account1",
|
|
||||||
Targets: []*reverseproxy.Target{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||||
}
|
}
|
||||||
account.InitOnce()
|
account.InitOnce()
|
||||||
@@ -1891,7 +1881,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
|||||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||||
@@ -1962,7 +1952,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
failed := waitTimeout(wg, time.Second)
|
failed := waitTimeout(wg, time.Second)
|
||||||
@@ -1971,82 +1961,6 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
|
||||||
manager, _, err := createManager(t)
|
|
||||||
require.NoError(t, err, "unable to create account manager")
|
|
||||||
|
|
||||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
|
||||||
require.NoError(t, err, "unable to create an account")
|
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
|
||||||
require.NoError(t, err, "unable to generate WireGuard key")
|
|
||||||
peerPubKey := key.PublicKey().String()
|
|
||||||
|
|
||||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
|
||||||
Key: peerPubKey,
|
|
||||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
|
||||||
}, false)
|
|
||||||
require.NoError(t, err, "unable to add peer")
|
|
||||||
|
|
||||||
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
|
||||||
|
|
||||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
|
||||||
require.NoError(t, err, "unable to get peer")
|
|
||||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
|
||||||
|
|
||||||
streamStartTime := time.Now().UTC()
|
|
||||||
|
|
||||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.False(t, peer.Status.Connected, "peer should be disconnected")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
|
||||||
|
|
||||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
|
||||||
|
|
||||||
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
|
|
||||||
|
|
||||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.True(t, peer.Status.Connected,
|
|
||||||
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
|
|
||||||
node2SyncTime := time.Now().UTC()
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
|
|
||||||
require.NoError(t, err, "node 2 should connect peer")
|
|
||||||
|
|
||||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
|
||||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
|
|
||||||
|
|
||||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
|
|
||||||
require.NoError(t, err, "stale connect should not return error")
|
|
||||||
|
|
||||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.True(t, peer.Status.Connected, "peer should still be connected")
|
|
||||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
|
|
||||||
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||||
manager, _, err := createManager(t)
|
manager, _, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
@@ -2069,7 +1983,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
|||||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
@@ -3122,8 +3036,6 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
|
|
||||||
|
|
||||||
return manager, updateManager, nil
|
return manager, updateManager, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3264,7 +3176,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC())
|
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -204,10 +204,6 @@ const (
|
|||||||
UserInviteLinkRegenerated Activity = 106
|
UserInviteLinkRegenerated Activity = 106
|
||||||
UserInviteLinkDeleted Activity = 107
|
UserInviteLinkDeleted Activity = 107
|
||||||
|
|
||||||
ServiceCreated Activity = 108
|
|
||||||
ServiceUpdated Activity = 109
|
|
||||||
ServiceDeleted Activity = 110
|
|
||||||
|
|
||||||
AccountDeleted Activity = 99999
|
AccountDeleted Activity = 99999
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -341,10 +337,6 @@ var activityMap = map[Activity]Code{
|
|||||||
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
|
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
|
||||||
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
|
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
|
||||||
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
|
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
|
||||||
|
|
||||||
ServiceCreated: {"Service created", "service.create"},
|
|
||||||
ServiceUpdated: {"Service updated", "service.update"},
|
|
||||||
ServiceDeleted: {"Service deleted", "service.delete"},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
|||||||
@@ -703,7 +703,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
t.Run("saving group linked to network router", func(t *testing.T) {
|
t.Run("saving group linked to network router", func(t *testing.T) {
|
||||||
permissionsManager := permissions.NewManager(manager.Store)
|
permissionsManager := permissions.NewManager(manager.Store)
|
||||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
|
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager)
|
||||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user