Compare commits

...

61 Commits

Author SHA1 Message Date
Diego Noguês
b016a1f0d0 feat: poc for token command on combined 2026-02-13 01:22:59 +01:00
Diego Noguês
c009055693 feat: adds netbird's proxy component to getting-started 2026-02-13 00:42:59 +01:00
Diego Noguês
14181c909c fix: remove duplicate import 2026-02-13 00:02:50 +01:00
mlsmaycon
a05dc3823d Merge branch 'main' into prototype/reverse-proxy
# Conflicts:
#	infrastructure_files/getting-started.sh
2026-02-12 19:27:12 +01:00
Misha Bragin
64b849c801 [self-hosted] add netbird server (#5232)
* Unified NetBird combined server (Management, Signal, Relay, STUN) as a single executable with richer YAML configuration, validation, and defaults.
  * Official Dockerfile/image for single-container deployment.
  * Optional in-process profiling endpoint for diagnostics.
  * Multiplexing to route HTTP/gRPC/WebSocket traffic via one port; runtime hooks to inject custom handlers.
* **Chores**
  * Updated deployment scripts, compose files, and reverse-proxy templates to target the combined server; added example configs and getting-started updates.
2026-02-12 19:24:43 +01:00
Diego Noguês
7d19bdf085 feat: adding traefik + nb's reverse proxy (#5303)
* feat: adding traefik and proxy component to getting-started

* feat: adding traefik and proxy component to getting-started

* feat: adding IPAM settings to docker compose and setting static ip to traefik

* fix: remove change to peers group all

* feat: switch to labels for traefik instead of static conf files

* feat: adding traefik and proxy component to getting-started

* feat: adding IPAM settings to docker compose and setting static ip to traefik

* fix: remove change to peers group all

* feat: switch to labels for traefik instead of static conf files

* chore: remove unnecessary comment

* chore: build

* chore: switching env var for NB_PROXY_DOMAIN
2026-02-12 19:12:20 +01:00
Diego Noguês
a1b048f2ad feat: adding traefik + nb reverse proxy 2026-02-12 18:43:35 +01:00
mlsmaycon
0bd227196e fix integration tests 2026-02-12 18:22:41 +01:00
Viktor Liu
eea7687ddf Fix lint and failing tests 2026-02-12 18:19:13 +01:00
mlsmaycon
57d3ee5aac optimize the DeriveClusterFromDomain function
1. validate domain only for proxy urls
2. use registered target cluster for custom domain extraction
2026-02-12 17:10:32 +01:00
pascal
cfdfdecc14 return error if unable to derive cluster on service creation 2026-02-12 16:57:16 +01:00
mlsmaycon
ac995bae6d rename url flag to domain and update validation 2026-02-12 16:28:29 +01:00
Alisdair MacLeod
41a5509ce0 fix nil pointer error in roundtripper 2026-02-12 15:19:19 +00:00
pascal
db5e26db94 rename domain type 2026-02-12 16:15:02 +01:00
Viktor Liu
fe975fb834 Fix missing lang attribute 2026-02-12 23:03:50 +08:00
Viktor Liu
e368d2995b Fix test 2026-02-12 22:57:28 +08:00
Viktor Liu
a3241d8376 Fix swallowed response codes 2026-02-12 22:54:17 +08:00
Alisdair MacLeod
6dfc5772ba fix nil pointer error in roundtripper 2026-02-12 14:44:07 +00:00
Viktor Liu
f70925178c Handle TCP port reuse for TIME-WAIT connections 2026-02-12 22:06:29 +08:00
Viktor Liu
9554934b92 Validate trusted proxies in OAuth callback getClientIP 2026-02-12 22:06:29 +08:00
Viktor Liu
7fdb824a37 Remove write permissions from /var/lib/netbird in proxy Dockerfile 2026-02-12 22:06:29 +08:00
Viktor Liu
412407adc0 Add .dockerignore to exclude sensitive files from build context 2026-02-12 22:06:29 +08:00
Viktor Liu
e0874d7de7 Add noopener to window.open in ErrorPage 2026-02-12 22:06:29 +08:00
pascal
8df1536cbb Merge branch 'main' into prototype/reverse-proxy 2026-02-12 15:05:14 +01:00
pascal
fcbacc62ec clear userID from access logs if not oidc 2026-02-12 14:50:35 +01:00
pascal
ee2ae45653 add permissions validation to domain manager 2026-02-12 14:31:23 +01:00
pascal
6f2f0f9ae4 exclude proxy peers on peers api 2026-02-12 13:49:05 +01:00
Alisdair MacLeod
c37ebc6fb3 add more metrics, improve metrics, reduce metrics impact on other packages 2026-02-12 12:36:54 +00:00
Viktor Liu
23abb5743c Treated tombstoned conns as new 2026-02-12 20:11:12 +08:00
Viktor Liu
b87aa0bc15 Address linter issues 2026-02-12 18:41:20 +08:00
Maycon Santos
69d4b5d821 [misc] Update sign pipeline version (#5296) 2026-02-12 11:31:49 +01:00
Viktor Liu
f1a65d732d Add proxy to license boundary check 2026-02-12 18:31:18 +08:00
Viktor Liu
a3c0ea3e71 Add proxy unit test workflow 2026-02-12 18:31:18 +08:00
Viktor Liu
abaf061c2a Skip nil client for health 2026-02-12 18:31:18 +08:00
pascal
e531fb54b1 ignore error 2026-02-12 11:20:22 +01:00
mlsmaycon
5fcfed5b16 add proxy tests 2026-02-12 11:19:10 +01:00
pascal
5f43449f67 move linter exceptions 2026-02-12 10:45:21 +01:00
mlsmaycon
6796601aa6 Generate a random nonce to ensure each OIDC request gets a unique state 2026-02-12 10:45:13 +01:00
pascal
1fc25c301b move linter exceptions 2026-02-12 10:11:49 +01:00
Viktor Liu
08ae281b2d Fix network monitor restarting the client in netstack mode 2026-02-12 16:48:31 +08:00
Viktor Liu
3dfa97dcbd [client] Fix stale entries in nftables with no handle (#5272) 2026-02-12 09:15:57 +01:00
Viktor Liu
1ddc9ce2bf [client] Fix nil pointer panic in device and engine code (#5287) 2026-02-12 09:15:42 +01:00
Viktor Liu
bd47f44c63 Preload services targets 2026-02-12 16:04:55 +08:00
Viktor Liu
381260911b Create unique token per proxy 2026-02-12 15:48:35 +08:00
Viktor Liu
38db42e7d6 Fix initial sync complete on empty service list 2026-02-12 15:48:35 +08:00
Viktor Liu
5d606d909d Add TTL-based expiry and cleanup for PKCE verifiers to prevent unbounded memory growth 2026-02-12 15:12:41 +08:00
Viktor Liu
d689718b50 Improve logging and error handling 2026-02-12 15:12:41 +08:00
pascal
54a73c6649 move linter exceptions 2026-02-12 02:10:00 +01:00
pascal
418377842e fix tests 2026-02-12 02:00:22 +01:00
pascal
15ef56e03d fix typos 2026-02-12 01:54:14 +01:00
pascal
917035f8e8 fix tests 2026-02-12 01:52:30 +01:00
pascal
963e3f5457 fix linter issues 2026-02-12 01:15:36 +01:00
pascal
e20b969188 fix linter issues 2026-02-12 01:02:13 +01:00
pascal
1c7059ee67 fix some tests 2026-02-12 00:16:33 +01:00
pascal
22a3365658 fix rename errors and tests 2026-02-11 22:34:50 +01:00
Maycon Santos
2de1949018 [client] Check if login is required on foreground mode (#5295) 2026-02-11 21:42:36 +01:00
pascal
08ab1e3478 rename reverse proxy to services 2026-02-11 21:39:51 +01:00
pascal
ebb1f4007d add id to request log search 2026-02-11 19:25:23 +01:00
pascal
acb53ece93 Merge branch 'prototype/reverse-proxy-logs-pagination' into prototype/reverse-proxy 2026-02-11 18:51:28 +01:00
Alisdair MacLeod
f3493ee042 add basic metrics for stress testing 2026-02-11 14:56:39 +00:00
Vlad
fc88399c23 [management] fixed ischild check (#5279) 2026-02-10 20:31:15 +03:00
133 changed files with 7027 additions and 2811 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
.env
.env.*
*.pem
*.key
*.crt
*.p12

View File

@@ -23,7 +23,7 @@ jobs:
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -39,11 +39,11 @@ jobs:
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -43,5 +43,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)

View File

@@ -46,6 +46,5 @@ jobs:
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -144,7 +144,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -204,7 +204,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
'
test_relay:
@@ -261,6 +261,53 @@ jobs:
-exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 ./proxy/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -20,7 +20,7 @@ jobs:
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:
fail-fast: false

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.0"
SIGN_PIPE_VER: "v0.1.1"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -106,6 +106,26 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-server
dir: combined
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-server
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
@@ -520,6 +540,55 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
@@ -598,6 +667,18 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
@@ -675,6 +756,19 @@ docker_manifests:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
brews:
- ids:
- default

View File

@@ -282,13 +282,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
}
defer authClient.Close()
needsLogin := false
err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
needsLogin, err := authClient.IsLoginRequired(ctx)
if err != nil {
return fmt.Errorf("check login required: %v", err)
}
jwtToken := ""

View File

@@ -31,6 +31,14 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string

View File

@@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
}
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(nftRule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
}
return nil
}
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *router) rollbackRules(pair firewall.RouterPair) {
keys := []string{
firewall.GenKey(firewall.ForwardingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
rule, exists := r.rules[ruleKey]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
@@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
// TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nil
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
rule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("prerouting rule %s not found", ruleKey)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
// (e.g. from failed flushes) and updates handles for all existing rules.
func (r *router) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[string]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("list rules: %w", err)
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
newRules[string(rule.UserData)] = rule
}
}
}
return nil
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
@@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
}
var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil {
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
delete(r.rules, ruleKey+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil {
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
delete(r.rules, ruleKey+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
}
if merr == nil {
@@ -1757,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}

View File

@@ -18,6 +18,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/id"
)
const (
@@ -719,3 +720,137 @@ func deleteWorkTable() {
}
}
}
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := "stale-rule-that-does-not-exist"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
err = r.refreshRulesMap()
require.NoError(t, err)
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
realRule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "real rule should still exist after refresh")
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
pair := firewall.RouterPair{
ID: "staletest",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true,
}
rtr := manager.router
// First add succeeds
err = rtr.AddNatRule(pair)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(pair))
})
// Corrupt the handle to simulate stale state
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
// Adding the same rule again should succeed despite stale handles
err = rtr.AddNatRule(pair)
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
// Verify rules exist in kernel
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err)
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found++
}
}
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
}

View File

@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// IsSupersededBy returns true if this connection should be replaced by a new one
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
// connections are superseded by a pure SYN (a new connection attempt for the same
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
if t.tombstone.Load() {
return true
}
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
if exists && !conn.IsSupersededBy(flags) {
t.updateState(key, conn, flags, direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true
}
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.IsTombstone() {
if !exists || conn.IsSupersededBy(flags) {
return false
}

View File

@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
})
}
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
// updateIfExists treats tombstoned entries as live, causing track() to skip
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
// because the entry is tombstoned, and the response packet gets dropped by ACL.
func TestTCPPortReuseTombstone(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and gracefully close a connection (server-initiated close)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Server sends FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
// Client sends FIN-ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Server sends final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Connection should be tombstoned
conn := tracker.connections[key]
require.NotNil(t, conn, "old connection should still be in map")
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
// Now reuse the same port for a new connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// The old tombstoned entry should be replaced with a new one
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState())
// SYN-ACK for the new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
// Data transfer should work
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
require.True(t, valid, "data should be allowed on new connection")
})
t.Run("Outbound port reuse after RST", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and RST a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
// Reuse the same port
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynSent, newConn.GetState())
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
})
t.Run("Inbound port reuse after close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := srcIP
serverIP := dstIP
clientPort := srcPort
serverPort := dstPort
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
// Inbound connection: client SYN → server SYN-ACK → client ACK
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateEstablished, conn.GetState())
// Server-initiated close to reach Closed/tombstoned:
// Server FIN (opposite dir) → CloseWait
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Client FIN-ACK (same dir as conn) → LastAck
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Server final ACK (opposite dir) → Closed → tombstoned
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
require.True(t, conn.IsTombstone())
// New inbound connection on same ports
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynReceived, newConn.GetState())
// Complete handshake: server SYN-ACK, then client ACK
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
conn := tracker.connections[key]
require.True(t, conn.IsTombstone())
// Late ACK should be rejected (tombstoned)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
// Late outbound ACK should not create a new connection (not a SYN)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
})
}
func TestTCPPortReuseTimeWait(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Active close: client (outbound initiator) sends FIN first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Server ACKs the FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Server sends its own FIN
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
// New outbound SYN on the same port (port reuse during TIME-WAIT)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
// SYN-ACK for new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish outbound connection and close via active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
// so the filter falls through to ACL check + TrackInbound (which creates
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
// Simulate what the filter does next: TrackInbound via the normal path
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
newConn := tracker.connections[invertedKey]
require.NotNil(t, newConn, "new inbound connection should be tracked")
require.Equal(t, TCPStateSynReceived, newConn.GetState())
require.False(t, newConn.IsTombstone())
})
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Late ACK retransmits during TIME-WAIT should still be accepted
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond

View File

@@ -5,6 +5,8 @@ import (
"context"
"fmt"
"io"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -16,9 +18,18 @@ const (
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
defaultLogChanSize = 1000
)
func getLogChannelSize() int {
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultLogChanSize
}
type Level uint32
const (
@@ -69,7 +80,7 @@ type Logger struct {
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
msgChannel: make(chan logMessage, logChannelSize),
msgChannel: make(chan logMessage, getLogChannelSize()),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() any {

View File

@@ -29,8 +29,9 @@ type PacketFilter interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
filter PacketFilter
mutex sync.RWMutex
closeOnce sync.Once
}
// newDeviceFilter constructor function
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
}
}
// Close closes the underlying tun device exactly once.
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
// and multiple code paths can trigger Close on the same device.
func (d *FilteredDevice) Close() error {
var err error
d.closeOnce.Do(func() {
err = d.Device.Close()
})
if err != nil {
return err
}
return nil
}
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {

View File

@@ -82,7 +82,9 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
if cErr := tunIface.Close(); cErr != nil {
log.Debugf("failed to close tun device: %v", cErr)
}
return nil, fmt.Errorf("error configuring interface: %s", err)
}

View File

@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
}
}()
return nsTunDev, tunNet, nil
return t.tundev, tunNet, nil
}
func (t *NetStackTun) Close() error {

View File

@@ -28,6 +28,7 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
@@ -543,11 +544,12 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.shutdownWg.Add(1)
wgIfaceName := e.wgInterface.Name()
go func() {
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.triggerClientRestart()
} else if err != nil {
@@ -1922,7 +1924,7 @@ func (e *Engine) triggerClientRestart() {
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor {
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
log.Infof("Network monitor is disabled, not starting")
return
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
@@ -38,11 +37,6 @@ func New() *NetworkMonitor {
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
if netstack.IsEnabled() {
log.Debugf("Network monitor: skipping in netstack mode")
return nil
}
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()

5
combined/Dockerfile Normal file
View File

@@ -0,0 +1,5 @@
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-server" ]
CMD ["--config", "/etc/netbird/config.yaml"]
COPY netbird-server /go/bin/netbird-server

715
combined/cmd/config.go Normal file
View File

@@ -0,0 +1,715 @@
package cmd
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"path"
"strings"
"time"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
// CombinedConfig is the root configuration for the combined server.
// The combined server is primarily a Management server with optional embedded
// Signal, Relay, and STUN services.
//
// Architecture:
// - Management: Always runs locally (this IS the management server)
// - Signal: Runs locally by default; disabled if server.signalUri is set
// - Relay: Runs locally by default; disabled if server.relays is set
// - STUN: Runs locally on port 3478 by default; disabled if server.stuns is set
//
// All user-facing settings are under "server". The relay/signal/management
// fields are internal and populated automatically from server settings.
type CombinedConfig struct {
Server ServerConfig `yaml:"server"`
// Internal configs - populated from Server settings, not user-configurable
Relay RelayConfig `yaml:"-"`
Signal SignalConfig `yaml:"-"`
Management ManagementConfig `yaml:"-"`
}
// ServerConfig contains server-wide settings
// In simplified mode, this contains all configuration
type ServerConfig struct {
ListenAddress string `yaml:"listenAddress"`
MetricsPort int `yaml:"metricsPort"`
HealthcheckAddress string `yaml:"healthcheckAddress"`
LogLevel string `yaml:"logLevel"`
LogFile string `yaml:"logFile"`
TLS TLSConfig `yaml:"tls"`
// Simplified config fields (used when relay/signal/management sections are omitted)
ExposedAddress string `yaml:"exposedAddress"` // Public address with protocol (e.g., "https://example.com:443")
StunPorts []int `yaml:"stunPorts"` // STUN ports (empty to disable local STUN)
AuthSecret string `yaml:"authSecret"` // Shared secret for relay authentication
DataDir string `yaml:"dataDir"` // Data directory for all services
// External service overrides (simplified mode)
// When these are set, the corresponding local service is NOT started
// and these values are used for client configuration instead
Stuns []HostConfig `yaml:"stuns"` // External STUN servers (disables local STUN)
Relays RelaysConfig `yaml:"relays"` // External relay servers (disables local relay)
SignalURI string `yaml:"signalUri"` // External signal server (disables local signal)
// Management settings (simplified mode)
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
Auth AuthConfig `yaml:"auth"`
Store StoreConfig `yaml:"store"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
// TLSConfig contains TLS/HTTPS settings
type TLSConfig struct {
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
LetsEncrypt LetsEncryptConfig `yaml:"letsencrypt"`
}
// LetsEncryptConfig contains Let's Encrypt settings
type LetsEncryptConfig struct {
Enabled bool `yaml:"enabled"`
DataDir string `yaml:"dataDir"`
Domains []string `yaml:"domains"`
Email string `yaml:"email"`
AWSRoute53 bool `yaml:"awsRoute53"`
}
// RelayConfig contains relay service settings
type RelayConfig struct {
Enabled bool `yaml:"enabled"`
ExposedAddress string `yaml:"exposedAddress"`
AuthSecret string `yaml:"authSecret"`
LogLevel string `yaml:"logLevel"`
Stun StunConfig `yaml:"stun"`
}
// StunConfig contains embedded STUN service settings
type StunConfig struct {
Enabled bool `yaml:"enabled"`
Ports []int `yaml:"ports"`
LogLevel string `yaml:"logLevel"`
}
// SignalConfig contains signal service settings
type SignalConfig struct {
Enabled bool `yaml:"enabled"`
LogLevel string `yaml:"logLevel"`
}
// ManagementConfig contains management service settings
type ManagementConfig struct {
Enabled bool `yaml:"enabled"`
LogLevel string `yaml:"logLevel"`
DataDir string `yaml:"dataDir"`
DnsDomain string `yaml:"dnsDomain"`
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
DisableDefaultPolicy bool `yaml:"disableDefaultPolicy"`
Auth AuthConfig `yaml:"auth"`
Stuns []HostConfig `yaml:"stuns"`
Relays RelaysConfig `yaml:"relays"`
SignalURI string `yaml:"signalUri"`
Store StoreConfig `yaml:"store"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
// AuthConfig contains authentication/identity provider settings
type AuthConfig struct {
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
}
// AuthStorageConfig contains auth storage settings
type AuthStorageConfig struct {
Type string `yaml:"type"`
File string `yaml:"file"`
}
// AuthOwnerConfig contains initial admin user settings
type AuthOwnerConfig struct {
Email string `yaml:"email"`
Password string `yaml:"password"`
}
// HostConfig represents a STUN/TURN/Signal host
type HostConfig struct {
URI string `yaml:"uri"`
Proto string `yaml:"proto,omitempty"` // udp, dtls, tcp, http, https - defaults based on URI scheme
Username string `yaml:"username,omitempty"`
Password string `yaml:"password,omitempty"`
}
// RelaysConfig contains external relay server settings for clients
type RelaysConfig struct {
Addresses []string `yaml:"addresses"`
CredentialsTTL string `yaml:"credentialsTTL"`
Secret string `yaml:"secret"`
}
// StoreConfig contains database settings
type StoreConfig struct {
Engine string `yaml:"engine"`
EncryptionKey string `yaml:"encryptionKey"`
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
}
// ReverseProxyConfig contains reverse proxy settings
type ReverseProxyConfig struct {
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
}
// DefaultConfig returns a CombinedConfig with default values
func DefaultConfig() *CombinedConfig {
return &CombinedConfig{
Server: ServerConfig{
ListenAddress: ":443",
MetricsPort: 9090,
HealthcheckAddress: ":9000",
LogLevel: "info",
LogFile: "console",
StunPorts: []int{3478},
DataDir: "/var/lib/netbird/",
Auth: AuthConfig{
Storage: AuthStorageConfig{
Type: "sqlite3",
},
},
Store: StoreConfig{
Engine: "sqlite",
},
},
Relay: RelayConfig{
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
Stun: StunConfig{
Enabled: false,
Ports: []int{3478},
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
},
},
Signal: SignalConfig{
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
},
Management: ManagementConfig{
DataDir: "/var/lib/netbird/",
Auth: AuthConfig{
Storage: AuthStorageConfig{
Type: "sqlite3",
},
},
Relays: RelaysConfig{
CredentialsTTL: "12h",
},
Store: StoreConfig{
Engine: "sqlite",
},
},
}
}
// hasRequiredSettings returns true if the configuration has the required server settings
func (c *CombinedConfig) hasRequiredSettings() bool {
return c.Server.ExposedAddress != ""
}
// parseExposedAddress extracts protocol, host, and host:port from the exposed address
// Input format: "https://example.com:443" or "http://example.com:8080" or "example.com:443"
// Returns: protocol ("https" or "http"), hostname only, and host:port
func parseExposedAddress(exposedAddress string) (protocol, hostname, hostPort string) {
// Default to https if no protocol specified
protocol = "https"
hostPort = exposedAddress
// Check for protocol prefix
if strings.HasPrefix(exposedAddress, "https://") {
protocol = "https"
hostPort = strings.TrimPrefix(exposedAddress, "https://")
} else if strings.HasPrefix(exposedAddress, "http://") {
protocol = "http"
hostPort = strings.TrimPrefix(exposedAddress, "http://")
}
// Extract hostname (without port)
hostname = hostPort
if host, _, err := net.SplitHostPort(hostPort); err == nil {
hostname = host
}
return protocol, hostname, hostPort
}
// ApplySimplifiedDefaults populates internal relay/signal/management configs from server settings.
// Management is always enabled. Signal, Relay, and STUN are enabled unless external
// overrides are configured (server.signalUri, server.relays, server.stuns).
func (c *CombinedConfig) ApplySimplifiedDefaults() {
if !c.hasRequiredSettings() {
return
}
// Parse exposed address to extract protocol and hostname
exposedProto, exposedHost, exposedHostPort := parseExposedAddress(c.Server.ExposedAddress)
// Check for external service overrides
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
hasExternalSignal := c.Server.SignalURI != ""
hasExternalStuns := len(c.Server.Stuns) > 0
// Default stunPorts to [3478] if not specified and no external STUN
if len(c.Server.StunPorts) == 0 && !hasExternalStuns {
c.Server.StunPorts = []int{3478}
}
c.applyRelayDefaults(exposedProto, exposedHostPort, hasExternalRelay, hasExternalStuns)
c.applySignalDefaults(hasExternalSignal)
c.applyManagementDefaults(exposedHost)
// Auto-configure client settings (stuns, relays, signalUri)
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
}
// applyRelayDefaults configures the relay service if no external relay is configured.
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
if hasExternalRelay {
return
}
c.Relay.Enabled = true
relayProto := "rel"
if exposedProto == "https" {
relayProto = "rels"
}
c.Relay.ExposedAddress = fmt.Sprintf("%s://%s", relayProto, exposedHostPort)
c.Relay.AuthSecret = c.Server.AuthSecret
if c.Relay.LogLevel == "" {
c.Relay.LogLevel = c.Server.LogLevel
}
// Enable local STUN only if no external STUN servers and stunPorts are configured
if !hasExternalStuns && len(c.Server.StunPorts) > 0 {
c.Relay.Stun.Enabled = true
c.Relay.Stun.Ports = c.Server.StunPorts
if c.Relay.Stun.LogLevel == "" {
c.Relay.Stun.LogLevel = c.Server.LogLevel
}
}
}
// applySignalDefaults configures the signal service if no external signal is configured.
func (c *CombinedConfig) applySignalDefaults(hasExternalSignal bool) {
if hasExternalSignal {
return
}
c.Signal.Enabled = true
if c.Signal.LogLevel == "" {
c.Signal.LogLevel = c.Server.LogLevel
}
}
// applyManagementDefaults configures the management service (always enabled).
func (c *CombinedConfig) applyManagementDefaults(exposedHost string) {
c.Management.Enabled = true
if c.Management.LogLevel == "" {
c.Management.LogLevel = c.Server.LogLevel
}
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
c.Management.DataDir = c.Server.DataDir
}
c.Management.DnsDomain = exposedHost
c.Management.DisableAnonymousMetrics = c.Server.DisableAnonymousMetrics
c.Management.DisableGeoliteUpdate = c.Server.DisableGeoliteUpdate
// Copy auth config from server if management auth issuer is not set
if c.Management.Auth.Issuer == "" && c.Server.Auth.Issuer != "" {
c.Management.Auth = c.Server.Auth
}
// Copy store config from server if not set
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
if c.Server.Store.Engine != "" {
c.Management.Store = c.Server.Store
}
}
// Copy reverse proxy config from server
if len(c.Server.ReverseProxy.TrustedHTTPProxies) > 0 || c.Server.ReverseProxy.TrustedHTTPProxiesCount > 0 || len(c.Server.ReverseProxy.TrustedPeers) > 0 {
c.Management.ReverseProxy = c.Server.ReverseProxy
}
}
// autoConfigureClientSettings sets up STUN/relay/signal URIs for clients
// External overrides from server config take precedence over auto-generated values
func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort string, hasExternalStuns, hasExternalRelay, hasExternalSignal bool) {
// Determine relay protocol from exposed protocol
relayProto := "rel"
if exposedProto == "https" {
relayProto = "rels"
}
// Configure STUN servers for clients
if hasExternalStuns {
// Use external STUN servers from server config
c.Management.Stuns = c.Server.Stuns
} else if len(c.Server.StunPorts) > 0 && len(c.Management.Stuns) == 0 {
// Auto-configure local STUN servers for all ports
for _, port := range c.Server.StunPorts {
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
URI: fmt.Sprintf("stun:%s:%d", exposedHost, port),
})
}
}
// Configure relay for clients
if hasExternalRelay {
// Use external relay config from server
c.Management.Relays = c.Server.Relays
} else if len(c.Management.Relays.Addresses) == 0 {
// Auto-configure local relay
c.Management.Relays.Addresses = []string{
fmt.Sprintf("%s://%s", relayProto, exposedHostPort),
}
}
if c.Management.Relays.Secret == "" {
c.Management.Relays.Secret = c.Server.AuthSecret
}
if c.Management.Relays.CredentialsTTL == "" {
c.Management.Relays.CredentialsTTL = "12h"
}
// Configure signal for clients
if hasExternalSignal {
// Use external signal URI from server config
c.Management.SignalURI = c.Server.SignalURI
} else if c.Management.SignalURI == "" {
// Auto-configure local signal
c.Management.SignalURI = fmt.Sprintf("%s://%s", exposedProto, exposedHostPort)
}
}
// LoadConfig loads configuration from a YAML file
func LoadConfig(configPath string) (*CombinedConfig, error) {
cfg := DefaultConfig()
if configPath == "" {
return cfg, nil
}
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
if err := yaml.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
// Populate internal configs from server settings
cfg.ApplySimplifiedDefaults()
return cfg, nil
}
// Validate validates the configuration
func (c *CombinedConfig) Validate() error {
if c.Server.ExposedAddress == "" {
return fmt.Errorf("server.exposedAddress is required")
}
if c.Server.DataDir == "" {
return fmt.Errorf("server.dataDir is required")
}
// Validate STUN ports
seen := make(map[int]bool)
for _, port := range c.Server.StunPorts {
if port <= 0 || port > 65535 {
return fmt.Errorf("invalid server.stunPorts value %d: must be between 1 and 65535", port)
}
if seen[port] {
return fmt.Errorf("duplicate STUN port %d in server.stunPorts", port)
}
seen[port] = true
}
// authSecret is required only if running local relay (no external relay configured)
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
if !hasExternalRelay && c.Server.AuthSecret == "" {
return fmt.Errorf("server.authSecret is required when running local relay")
}
return nil
}
// HasTLSCert returns true if TLS certificate files are configured
func (c *CombinedConfig) HasTLSCert() bool {
return c.Server.TLS.CertFile != "" && c.Server.TLS.KeyFile != ""
}
// HasLetsEncrypt returns true if Let's Encrypt is configured
func (c *CombinedConfig) HasLetsEncrypt() bool {
return c.Server.TLS.LetsEncrypt.Enabled &&
c.Server.TLS.LetsEncrypt.DataDir != "" &&
len(c.Server.TLS.LetsEncrypt.Domains) > 0
}
// parseExplicitProtocol parses an explicit protocol string to nbconfig.Protocol
func parseExplicitProtocol(proto string) (nbconfig.Protocol, bool) {
switch strings.ToLower(proto) {
case "udp":
return nbconfig.UDP, true
case "dtls":
return nbconfig.DTLS, true
case "tcp":
return nbconfig.TCP, true
case "http":
return nbconfig.HTTP, true
case "https":
return nbconfig.HTTPS, true
default:
return "", false
}
}
// parseStunProtocol determines protocol for STUN/TURN servers.
// stun: → UDP, stuns: → DTLS, turn: → UDP, turns: → DTLS
// Explicit proto overrides URI scheme. Defaults to UDP.
func parseStunProtocol(uri, proto string) nbconfig.Protocol {
if proto != "" {
if p, ok := parseExplicitProtocol(proto); ok {
return p
}
}
uri = strings.ToLower(uri)
switch {
case strings.HasPrefix(uri, "stuns:"):
return nbconfig.DTLS
case strings.HasPrefix(uri, "turns:"):
return nbconfig.DTLS
default:
// stun:, turn:, or no scheme - default to UDP
return nbconfig.UDP
}
}
// parseSignalProtocol determines protocol for Signal servers.
// https:// → HTTPS, http:// → HTTP. Defaults to HTTPS.
func parseSignalProtocol(uri string) nbconfig.Protocol {
uri = strings.ToLower(uri)
switch {
case strings.HasPrefix(uri, "http://"):
return nbconfig.HTTP
default:
// https:// or no scheme - default to HTTPS
return nbconfig.HTTPS
}
}
// stripSignalProtocol removes the protocol prefix from a signal URI.
// Returns just the host:port (e.g., "selfhosted2.demo.netbird.io:443").
func stripSignalProtocol(uri string) string {
uri = strings.TrimPrefix(uri, "https://")
uri = strings.TrimPrefix(uri, "http://")
return uri
}
// ToManagementConfig converts CombinedConfig to management server config
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
mgmt := c.Management
// Build STUN hosts
var stuns []*nbconfig.Host
for _, s := range mgmt.Stuns {
stuns = append(stuns, &nbconfig.Host{
URI: s.URI,
Proto: parseStunProtocol(s.URI, s.Proto),
Username: s.Username,
Password: s.Password,
})
}
// Build relay config
var relayConfig *nbconfig.Relay
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
var ttl time.Duration
if mgmt.Relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
}
}
relayConfig = &nbconfig.Relay{
Addresses: mgmt.Relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: mgmt.Relays.Secret,
}
}
// Build signal config
var signalConfig *nbconfig.Host
if mgmt.SignalURI != "" {
signalConfig = &nbconfig.Host{
URI: stripSignalProtocol(mgmt.SignalURI),
Proto: parseSignalProtocol(mgmt.SignalURI),
}
}
// Build store config
storeConfig := nbconfig.StoreConfig{
Engine: types.Engine(mgmt.Store.Engine),
}
// Build reverse proxy config
reverseProxy := nbconfig.ReverseProxy{
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
}
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
if prefix, err := netip.ParsePrefix(p); err == nil {
reverseProxy.TrustedHTTPProxies = append(reverseProxy.TrustedHTTPProxies, prefix)
}
}
for _, p := range mgmt.ReverseProxy.TrustedPeers {
if prefix, err := netip.ParsePrefix(p); err == nil {
reverseProxy.TrustedPeers = append(reverseProxy.TrustedPeers, prefix)
}
}
// Build HTTP config (required, even if empty)
httpConfig := &nbconfig.HttpServerConfig{}
// Build embedded IDP config (always enabled in combined server)
storageFile := mgmt.Auth.Storage.File
if storageFile == "" {
storageFile = path.Join(mgmt.DataDir, "idp.db")
}
embeddedIdP := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: mgmt.Auth.Storage.Type,
Config: idp.EmbeddedStorageTypeConfig{
File: storageFile,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
embeddedIdP.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
}
}
// Set HTTP config fields for embedded IDP
httpConfig.AuthIssuer = mgmt.Auth.Issuer
httpConfig.IdpSignKeyRefreshEnabled = mgmt.Auth.SignKeyRefreshEnabled
return &nbconfig.Config{
Stuns: stuns,
Relay: relayConfig,
Signal: signalConfig,
Datadir: mgmt.DataDir,
DataStoreEncryptionKey: mgmt.Store.EncryptionKey,
HttpConfig: httpConfig,
StoreConfig: storeConfig,
ReverseProxy: reverseProxy,
DisableDefaultPolicy: mgmt.DisableDefaultPolicy,
EmbeddedIdP: embeddedIdP,
}, nil
}
// ApplyEmbeddedIdPConfig applies embedded IdP configuration to the management config.
// This mirrors the logic in management/cmd/management.go ApplyEmbeddedIdPConfig.
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort int, disableSingleAccMode bool) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
// Embedded IdP requires single account mode
if disableSingleAccMode {
return fmt.Errorf("embedded IdP requires single account mode; multiple account mode is not supported with embedded IdP")
}
// Set LocalAddress for embedded IdP, used for internal JWT validation
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
// Set storage defaults based on Datadir
if cfg.EmbeddedIdP.Storage.Type == "" {
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer
// Ensure HttpConfig exists
if cfg.HttpConfig == nil {
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
return nil
}
// EnsureEncryptionKey generates an encryption key if not set.
// Unlike management server, we don't write back to the config file.
func EnsureEncryptionKey(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}
log.WithContext(ctx).Infof("DataStoreEncryptionKey is not set, generating a new key")
key, err := crypt.GenerateKey()
if err != nil {
return fmt.Errorf("failed to generate datastore encryption key: %v", err)
}
cfg.DataStoreEncryptionKey = key
keyPreview := key[:8] + "..."
log.WithContext(ctx).Warnf("DataStoreEncryptionKey generated (%s); add it to your config file under 'server.store.encryptionKey' to persist across restarts", keyPreview)
return nil
}
// LogConfigInfo logs informational messages about the loaded configuration
func LogConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
if cfg.Relay != nil {
log.Infof("Relay addresses: %v", cfg.Relay.Addresses)
}
}

33
combined/cmd/pprof.go Normal file
View File

@@ -0,0 +1,33 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

711
combined/cmd/root.go Normal file
View File

@@ -0,0 +1,711 @@
package cmd
import (
"context"
"crypto/sha256"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/coder/websocket"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"go.opentelemetry.io/otel/metric"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/encryption"
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/relay/healthcheck"
relayServer "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/relay/server/listener/ws"
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/wsproxy"
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
)
var (
configPath string
config *CombinedConfig
rootCmd = &cobra.Command{
Use: "combined",
Short: "Combined Netbird server (Management + Signal + Relay + STUN)",
Long: `Combined Netbird server for self-hosted deployments.
All services (Management, Signal, Relay) are multiplexed on a single port.
Optional STUN server runs on separate UDP ports.
Configuration is loaded from a YAML file specified with --config.`,
SilenceUsage: true,
SilenceErrors: true,
RunE: execute,
}
)
func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
}
func Execute() error {
return rootCmd.Execute()
}
func waitForExitSignal() {
osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
<-osSigs
}
func execute(cmd *cobra.Command, _ []string) error {
if err := initializeConfig(); err != nil {
return err
}
// Management is required as the base server when signal or relay are enabled
if (config.Signal.Enabled || config.Relay.Enabled) && !config.Management.Enabled {
return fmt.Errorf("management must be enabled when signal or relay are enabled (provides the base HTTP server)")
}
servers, err := createAllServers(cmd.Context(), config)
if err != nil {
return err
}
// Register services with management's gRPC server using AfterInit hook
setupServerHooks(servers, config)
// Start management server (this also starts the HTTP listener)
if servers.mgmtSrv != nil {
if err := servers.mgmtSrv.Start(cmd.Context()); err != nil {
cleanupSTUNListeners(servers.stunListeners)
return fmt.Errorf("failed to start management server: %w", err)
}
}
// Start all other servers
wg := sync.WaitGroup{}
startServers(&wg, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.metricsServer)
waitForExitSignal()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = shutdownServers(ctx, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.mgmtSrv, servers.metricsServer)
wg.Wait()
return err
}
// initializeConfig loads and validates the configuration, then initializes logging.
func initializeConfig() error {
var err error
config, err = LoadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
if err := config.Validate(); err != nil {
return fmt.Errorf("invalid config: %w", err)
}
if err := util.InitLog(config.Server.LogLevel, config.Server.LogFile); err != nil {
return fmt.Errorf("failed to initialize log: %w", err)
}
if dsn := config.Server.Store.DSN; dsn != "" {
switch strings.ToLower(config.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
log.Infof("Starting combined NetBird server")
logConfig(config)
logEnvVars()
return nil
}
// serverInstances holds all server instances created during startup.
type serverInstances struct {
relaySrv *relayServer.Server
mgmtSrv *mgmtServer.BaseServer
signalSrv *signalServer.Server
healthcheck *healthcheck.Server
stunServer *stun.Server
stunListeners []*net.UDPConn
metricsServer *sharedMetrics.Metrics
}
// createAllServers creates all server instances based on configuration.
func createAllServers(ctx context.Context, cfg *CombinedConfig) (*serverInstances, error) {
metricsServer, err := sharedMetrics.NewServer(cfg.Server.MetricsPort, "")
if err != nil {
return nil, fmt.Errorf("failed to create metrics server: %w", err)
}
servers := &serverInstances{
metricsServer: metricsServer,
}
_, tlsSupport, err := handleTLSConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to setup TLS config: %w", err)
}
if err := servers.createRelayServer(cfg, tlsSupport); err != nil {
return nil, err
}
if err := servers.createManagementServer(ctx, cfg); err != nil {
return nil, err
}
if err := servers.createSignalServer(ctx, cfg); err != nil {
return nil, err
}
if err := servers.createHealthcheckServer(cfg); err != nil {
return nil, err
}
return servers, nil
}
func (s *serverInstances) createRelayServer(cfg *CombinedConfig, tlsSupport bool) error {
if !cfg.Relay.Enabled {
return nil
}
var err error
s.stunListeners, err = createSTUNListeners(cfg)
if err != nil {
return err
}
hashedSecret := sha256.Sum256([]byte(cfg.Relay.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
relayCfg := relayServer.Config{
Meter: s.metricsServer.Meter,
ExposedAddress: cfg.Relay.ExposedAddress,
AuthValidator: authenticator,
TLSSupport: tlsSupport,
}
s.relaySrv, err = createRelayServer(relayCfg, s.stunListeners)
if err != nil {
return err
}
log.Infof("Relay server created")
if len(s.stunListeners) > 0 {
s.stunServer = stun.NewServer(s.stunListeners, cfg.Relay.Stun.LogLevel)
}
return nil
}
func (s *serverInstances) createManagementServer(ctx context.Context, cfg *CombinedConfig) error {
if !cfg.Management.Enabled {
return nil
}
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return fmt.Errorf("failed to create management config: %w", err)
}
_, portStr, portErr := net.SplitHostPort(cfg.Server.ListenAddress)
if portErr != nil {
portStr = "443"
}
mgmtPort, _ := strconv.Atoi(portStr)
if err := ApplyEmbeddedIdPConfig(ctx, mgmtConfig, mgmtPort, false); err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to apply embedded IdP config: %w", err)
}
if err := EnsureEncryptionKey(ctx, mgmtConfig); err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to ensure encryption key: %w", err)
}
LogConfigInfo(mgmtConfig)
s.mgmtSrv, err = createManagementServer(cfg, mgmtConfig)
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create management server: %w", err)
}
// Inject externally-managed AppMetrics so management uses the shared metrics server
appMetrics, err := telemetry.NewAppMetricsWithMeter(ctx, s.metricsServer.Meter)
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create management app metrics: %w", err)
}
mgmtServer.Inject[telemetry.AppMetrics](s.mgmtSrv, appMetrics)
log.Infof("Management server created")
return nil
}
func (s *serverInstances) createSignalServer(ctx context.Context, cfg *CombinedConfig) error {
if !cfg.Signal.Enabled {
return nil
}
var err error
s.signalSrv, err = signalServer.NewServer(ctx, s.metricsServer.Meter, "signal_")
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create signal server: %w", err)
}
log.Infof("Signal server created")
return nil
}
func (s *serverInstances) createHealthcheckServer(cfg *CombinedConfig) error {
hCfg := healthcheck.Config{
ListenAddress: cfg.Server.HealthcheckAddress,
ServiceChecker: s.relaySrv,
}
var err error
s.healthcheck, err = createHealthCheck(hCfg, s.stunListeners)
return err
}
// setupServerHooks registers services with management's gRPC server.
func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
if servers.mgmtSrv == nil {
return
}
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
grpcSrv := s.GRPCServer()
if servers.signalSrv != nil {
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
}
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)")
}
})
}
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
if srv != nil {
instanceURL := srv.InstanceURL()
log.Infof("Relay server instance URL: %s", instanceURL.String())
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
}
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start healthcheck server: %v", err)
}
}()
if stunServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
}()
}
}
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
var errs error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
}
if stunServer != nil {
if err := stunServer.Shutdown(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
}
}
if srv != nil {
if err := srv.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
}
}
if mgmtSrv != nil {
log.Infof("shutting down management and signal servers")
if err := mgmtSrv.Stop(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close management server: %w", err))
}
}
if metricsServer != nil {
log.Infof("shutting down metrics server")
if err := metricsServer.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
}
}
return errs
}
func createHealthCheck(hCfg healthcheck.Config, stunListeners []*net.UDPConn) (*healthcheck.Server, error) {
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create healthcheck server: %w", err)
}
return httpHealthcheck, nil
}
func createRelayServer(cfg relayServer.Config, stunListeners []*net.UDPConn) (*relayServer.Server, error) {
srv, err := relayServer.NewServer(cfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create relay server: %w", err)
}
return srv, nil
}
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
for _, l := range stunListeners {
_ = l.Close()
}
}
func createSTUNListeners(cfg *CombinedConfig) ([]*net.UDPConn, error) {
var stunListeners []*net.UDPConn
if cfg.Relay.Stun.Enabled {
for _, port := range cfg.Relay.Stun.Ports {
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create STUN listener on port %d: %w", port, err)
}
stunListeners = append(stunListeners, listener)
log.Infof("STUN server listening on UDP port %d", port)
}
}
return stunListeners, nil
}
func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
tlsCfg := cfg.Server.TLS
if tlsCfg.LetsEncrypt.AWSRoute53 {
log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
r53 := encryption.Route53TLS{
DataDir: tlsCfg.LetsEncrypt.DataDir,
Email: tlsCfg.LetsEncrypt.Email,
Domains: tlsCfg.LetsEncrypt.Domains,
}
tc, err := r53.GetCertificate()
if err != nil {
return nil, false, err
}
return tc, true, nil
}
if cfg.HasLetsEncrypt() {
log.Infof("setting up TLS with Let's Encrypt")
certManager, err := encryption.CreateCertManager(tlsCfg.LetsEncrypt.DataDir, tlsCfg.LetsEncrypt.Domains...)
if err != nil {
return nil, false, fmt.Errorf("failed creating LetsEncrypt cert manager: %w", err)
}
return certManager.TLSConfig(), true, nil
}
if cfg.HasTLSCert() {
log.Debugf("using file based TLS config")
tc, err := encryption.LoadTLSConfig(tlsCfg.CertFile, tlsCfg.KeyFile)
if err != nil {
return nil, false, err
}
return tc, true, nil
}
return nil, false, nil
}
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management
dnsDomain := mgmt.DnsDomain
singleAccModeDomain := dnsDomain
// Extract port from listen address
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
if err != nil {
// If no port specified, assume default
portStr = "443"
}
mgmtPort, _ := strconv.Atoi(portStr)
mgmtSrv := mgmtServer.NewServer(
mgmtConfig,
dnsDomain,
singleAccModeDomain,
mgmtPort,
cfg.Server.MetricsPort,
mgmt.DisableAnonymousMetrics,
mgmt.DisableGeoliteUpdate,
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
true,
)
return mgmtSrv, nil
}
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn net.Conn)
if relaySrv != nil {
relayAcceptFn = relaySrv.RelayAccept()
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
// Native gRPC traffic (HTTP/2 with gRPC content-type)
case r.ProtoMajor == 2 && (strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto")):
grpcServer.ServeHTTP(w, r)
// WebSocket proxy for Management gRPC
case r.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
wsProxy.Handler().ServeHTTP(w, r)
// WebSocket proxy for Signal gRPC
case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent:
if cfg.Signal.Enabled {
wsProxy.Handler().ServeHTTP(w, r)
} else {
http.Error(w, "Signal service not enabled", http.StatusNotFound)
}
// Relay WebSocket
case r.URL.Path == "/relay":
if relayAcceptFn != nil {
handleRelayWebSocket(w, r, relayAcceptFn, cfg)
} else {
http.Error(w, "Relay service not enabled", http.StatusNotFound)
}
// Management HTTP API (default)
default:
httpHandler.ServeHTTP(w, r)
}
})
}
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
wsConn, err := websocket.Accept(w, r, acceptOptions)
if err != nil {
log.Errorf("failed to accept relay ws connection: %s", err)
return
}
connRemoteAddr := r.RemoteAddr
if r.Header.Get("X-Real-Ip") != "" && r.Header.Get("X-Real-Port") != "" {
connRemoteAddr = net.JoinHostPort(r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
}
rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
log.Debugf("Relay WS client connected from: %s", rAddr)
conn := ws.NewConn(wsConn, lAddr, rAddr)
acceptFn(conn)
}
// logConfig prints all configuration parameters for debugging
func logConfig(cfg *CombinedConfig) {
log.Info("=== Configuration ===")
logServerConfig(cfg)
logComponentsConfig(cfg)
logRelayConfig(cfg)
logManagementConfig(cfg)
log.Info("=== End Configuration ===")
}
func logServerConfig(cfg *CombinedConfig) {
log.Info("--- Server ---")
log.Infof(" Listen address: %s", cfg.Server.ListenAddress)
log.Infof(" Exposed address: %s", cfg.Server.ExposedAddress)
log.Infof(" Healthcheck address: %s", cfg.Server.HealthcheckAddress)
log.Infof(" Metrics port: %d", cfg.Server.MetricsPort)
log.Infof(" Log level: %s", cfg.Server.LogLevel)
log.Infof(" Data dir: %s", cfg.Server.DataDir)
switch {
case cfg.HasTLSCert():
log.Infof(" TLS: cert=%s, key=%s", cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
case cfg.HasLetsEncrypt():
log.Infof(" TLS: Let's Encrypt (domains=%v)", cfg.Server.TLS.LetsEncrypt.Domains)
default:
log.Info(" TLS: disabled (using reverse proxy)")
}
}
func logComponentsConfig(cfg *CombinedConfig) {
log.Info("--- Components ---")
log.Infof(" Management: %v (log level: %s)", cfg.Management.Enabled, cfg.Management.LogLevel)
log.Infof(" Signal: %v (log level: %s)", cfg.Signal.Enabled, cfg.Signal.LogLevel)
log.Infof(" Relay: %v (log level: %s)", cfg.Relay.Enabled, cfg.Relay.LogLevel)
}
func logRelayConfig(cfg *CombinedConfig) {
if !cfg.Relay.Enabled {
return
}
log.Info("--- Relay ---")
log.Infof(" Exposed address: %s", cfg.Relay.ExposedAddress)
log.Infof(" Auth secret: %s...", maskSecret(cfg.Relay.AuthSecret))
if cfg.Relay.Stun.Enabled {
log.Infof(" STUN ports: %v (log level: %s)", cfg.Relay.Stun.Ports, cfg.Relay.Stun.LogLevel)
} else {
log.Info(" STUN: disabled")
}
}
func logManagementConfig(cfg *CombinedConfig) {
if !cfg.Management.Enabled {
return
}
log.Info("--- Management ---")
log.Infof(" Data dir: %s", cfg.Management.DataDir)
log.Infof(" DNS domain: %s", cfg.Management.DnsDomain)
log.Infof(" Store engine: %s", cfg.Management.Store.Engine)
if cfg.Server.Store.DSN != "" {
log.Infof(" Store DSN: %s", maskDSNPassword(cfg.Server.Store.DSN))
}
log.Info(" Auth (embedded IdP):")
log.Infof(" Issuer: %s", cfg.Management.Auth.Issuer)
log.Infof(" Dashboard redirect URIs: %v", cfg.Management.Auth.DashboardRedirectURIs)
log.Infof(" CLI redirect URIs: %v", cfg.Management.Auth.CLIRedirectURIs)
log.Info(" Client settings:")
log.Infof(" Signal URI: %s", cfg.Management.SignalURI)
for _, s := range cfg.Management.Stuns {
log.Infof(" STUN: %s", s.URI)
}
if len(cfg.Management.Relays.Addresses) > 0 {
log.Infof(" Relay addresses: %v", cfg.Management.Relays.Addresses)
log.Infof(" Relay credentials TTL: %s", cfg.Management.Relays.CredentialsTTL)
}
}
// logEnvVars logs all NB_ environment variables that are currently set
func logEnvVars() {
log.Info("=== Environment Variables ===")
found := false
for _, env := range os.Environ() {
if strings.HasPrefix(env, "NB_") {
key, _, _ := strings.Cut(env, "=")
value := os.Getenv(key)
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
value = maskSecret(value)
}
log.Infof(" %s=%s", key, value)
found = true
}
}
if !found {
log.Info(" (none set)")
}
log.Info("=== End Environment Variables ===")
}
// maskDSNPassword masks the password in a DSN string.
// Handles both key=value format ("password=secret") and URI format ("user:secret@host").
func maskDSNPassword(dsn string) string {
// Key=value format: "host=localhost user=nb password=secret dbname=nb"
if strings.Contains(dsn, "password=") {
parts := strings.Fields(dsn)
for i, p := range parts {
if strings.HasPrefix(p, "password=") {
parts[i] = "password=****"
}
}
return strings.Join(parts, " ")
}
// URI format: "user:password@host..."
if atIdx := strings.Index(dsn, "@"); atIdx != -1 {
prefix := dsn[:atIdx]
if colonIdx := strings.Index(prefix, ":"); colonIdx != -1 {
return prefix[:colonIdx+1] + "****" + dsn[atIdx:]
}
}
return dsn
}
// maskSecret returns first 4 chars of secret followed by "..."
func maskSecret(secret string) string {
if len(secret) <= 4 {
return "****"
}
return secret[:4] + "..."
}

219
combined/cmd/token.go Normal file
View File

@@ -0,0 +1,219 @@
package cmd
import (
"context"
"fmt"
"os"
"strconv"
"text/tabwriter"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
var (
tokenName string
tokenExpireIn string
tokenDatadir string
tokenCmd = &cobra.Command{
Use: "token",
Short: "Manage proxy access tokens",
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
}
tokenCreateCmd = &cobra.Command{
Use: "create",
Short: "Create a new proxy access token",
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
RunE: tokenCreateRun,
}
tokenListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all proxy access tokens",
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
RunE: tokenListRun,
}
tokenRevokeCmd = &cobra.Command{
Use: "revoke [token-id]",
Short: "Revoke a proxy access token",
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
Args: cobra.ExactArgs(1),
RunE: tokenRevokeRun,
}
)
func init() {
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
tokenCreateCmd.MarkFlagRequired("name") //nolint
tokenCmd.AddCommand(tokenCreateCmd, tokenListCmd, tokenRevokeCmd)
rootCmd.AddCommand(tokenCmd)
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
//nolint
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
// Load combined server YAML config
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
// Get datadir from config or override
datadir := cfg.Server.DataDir
if tokenDatadir != "" {
datadir = tokenDatadir
}
// Get store engine from config
storeEngine := types.Engine(cfg.Server.Store.Engine)
if storeEngine == "" {
storeEngine = "sqlite"
}
s, err := store.NewStore(ctx, storeEngine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
expiresIn, err := parseDuration(tokenExpireIn)
if err != nil {
return fmt.Errorf("parse expiration: %w", err)
}
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
return fmt.Errorf("save token: %w", err)
}
fmt.Println("Token created successfully!") //nolint:forbidigo
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
fmt.Println() //nolint:forbidigo
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
return nil
})
}
func tokenListRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
if err != nil {
return fmt.Errorf("list tokens: %w", err)
}
if len(tokens) == 0 {
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
for _, t := range tokens {
expires := "never"
if t.ExpiresAt != nil {
expires = t.ExpiresAt.Format("2006-01-02")
}
lastUsed := "never"
if t.LastUsed != nil {
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
}
revoked := "no"
if t.Revoked {
revoked = "yes"
}
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
t.ID,
t.Name,
t.CreatedAt.Format("2006-01-02"),
expires,
lastUsed,
revoked,
)
}
w.Flush()
return nil
})
}
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokenID := args[0]
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
return fmt.Errorf("revoke token: %w", err)
}
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
return nil
})
}
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
// An empty string returns zero duration (no expiration).
func parseDuration(s string) (time.Duration, error) {
if len(s) == 0 {
return 0, nil
}
if s[len(s)-1] == 'd' {
d, err := strconv.Atoi(s[:len(s)-1])
if err != nil {
return 0, fmt.Errorf("invalid day format: %s", s)
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return time.Duration(d) * 24 * time.Hour, nil
}
d, err := time.ParseDuration(s)
if err != nil {
return 0, err
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return d, nil
}

View File

@@ -0,0 +1,111 @@
# NetBird Combined Server Configuration
# Copy this file to config.yaml and customize for your deployment
#
# This is a Management server with optional embedded Signal, Relay, and STUN services.
# By default, all services run locally. You can use external services instead by
# setting the corresponding override fields.
#
# Architecture:
# - Management: Always runs locally (this IS the management server)
# - Signal: Local by default; set 'signalUri' to use external (disables local)
# - Relay: Local by default; set 'relays' to use external (disables local)
# - STUN: Local on port 3478 by default; set 'stuns' to use external instead
server:
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
listenAddress: ":443"
# Public address that peers will use to connect to this server
# Used for relay connections and management DNS domain
# Format: protocol://hostname:port (e.g., https://server.mycompany.com:443)
exposedAddress: "https://server.mycompany.com:443"
# STUN server ports (defaults to [3478] if not specified; set 'stuns' to use external)
# stunPorts:
# - 3478
# Metrics endpoint port
metricsPort: 9090
# Healthcheck endpoint address
healthcheckAddress: ":9000"
# Logging configuration
logLevel: "info" # Default log level for all components: panic, fatal, error, warn, info, debug, trace
logFile: "console" # "console" or path to log file
# TLS configuration (optional)
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
dataDir: ""
domains: []
email: ""
awsRoute53: false
# Shared secret for relay authentication (required when running local relay)
authSecret: "your-secret-key-here"
# Data directory for all services
dataDir: "/var/lib/netbird/"
# ============================================================================
# External Service Overrides (optional)
# Use these to point to external Signal, Relay, or STUN servers instead of
# running them locally. When set, the corresponding local service is disabled.
# ============================================================================
# External STUN servers - disables local STUN server
# stuns:
# - uri: "stun:stun.example.com:3478"
# - uri: "stun:stun.example.com:3479"
# External relay servers - disables local relay server
# relays:
# addresses:
# - "rels://relay.example.com:443"
# credentialsTTL: "12h"
# secret: "relay-shared-secret"
# External signal server - disables local signal server
# signalUri: "https://signal.example.com:443"
# ============================================================================
# Management Settings
# ============================================================================
# Metrics and updates
disableAnonymousMetrics: false
disableGeoliteUpdate: false
# Embedded authentication/identity provider (Dex) configuration (always enabled)
auth:
# OIDC issuer URL - must be publicly accessible
issuer: "https://server.mycompany.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.netbird.io/nb-auth"
- "https://app.netbird.io/nb-silent-auth"
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# Optional initial admin user
# owner:
# email: "admin@example.com"
# password: "initial-password"
# Store configuration
store:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Reverse proxy settings (optional)
# reverseProxy:
# trustedHTTPProxies: []
# trustedHTTPProxiesCount: 0
# trustedPeers: []

View File

@@ -0,0 +1,115 @@
# Simplified Combined NetBird Server Configuration
# Copy this file to config.yaml and customize for your deployment
# Server-wide settings
server:
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
listenAddress: ":443"
# Metrics endpoint port
metricsPort: 9090
# Healthcheck endpoint address
healthcheckAddress: ":9000"
# Logging configuration
logLevel: "info" # panic, fatal, error, warn, info, debug, trace
logFile: "console" # "console" or path to log file
# TLS configuration (optional)
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
dataDir: ""
domains: []
email: ""
awsRoute53: false
# Relay service configuration
relay:
# Enable/disable the relay service
enabled: true
# Public address that peers will use to connect to this relay
# Format: hostname:port or ip:port
exposedAddress: "relay.example.com:443"
# Shared secret for relay authentication (required when enabled)
authSecret: "your-secret-key-here"
# Log level for relay (reserved for future use, currently uses global log level)
logLevel: "info"
# Embedded STUN server (optional)
stun:
enabled: false
ports: [3478]
logLevel: "info"
# Signal service configuration
signal:
# Enable/disable the signal service
enabled: true
# Log level for signal (reserved for future use, currently uses global log level)
logLevel: "info"
# Management service configuration
management:
# Enable/disable the management service
enabled: true
# Data directory for management service
dataDir: "/var/lib/netbird/"
# DNS domain for the management server
dnsDomain: ""
# Metrics and updates
disableAnonymousMetrics: false
disableGeoliteUpdate: false
auth:
# OIDC issuer URL - must be publicly accessible
issuer: "https://management.example.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.example.com/nb-auth"
- "https://app.example.com/nb-silent-auth"
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# Optional initial admin user
# owner:
# email: "admin@example.com"
# password: "initial-password"
# External STUN servers (for client config)
stuns: []
# - uri: "stun:stun.example.com:3478"
# External relay servers (for client config)
relays:
addresses: []
# - "rels://relay.example.com:443"
credentialsTTL: "12h"
secret: ""
# External signal server URI (for client config)
signalUri: ""
# Store configuration
store:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Reverse proxy settings
reverseProxy:
trustedHTTPProxies: []
trustedHTTPProxiesCount: 0
trustedPeers: []

13
combined/main.go Normal file
View File

@@ -0,0 +1,13 @@
package main
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/combined/cmd"
)
func main() {
if err := cmd.Execute(); err != nil {
log.Fatalf("failed to execute command: %v", err)
}
}

2
go.mod
View File

@@ -69,7 +69,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0

4
go.sum
View File

@@ -406,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
FROM golang:1.25-bookworm AS builder
WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y gcc libc6-dev && rm -rf /var/lib/apt/lists/*
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o netbird-mgmt ./management
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
CMD ["--log-file", "console"]
COPY --from=builder /app/netbird-mgmt /go/bin/netbird-mgmt

View File

@@ -57,7 +57,7 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
config, err = LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
}
@@ -135,35 +135,35 @@ var (
}
)
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
func LoadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
loadedConfig := &nbconfig.Config{}
if _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
applyCommandLineOverrides(loadedConfig)
ApplyCommandLineOverrides(loadedConfig)
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
err := ApplyEmbeddedIdPConfig(ctx, loadedConfig)
if err != nil {
return nil, err
}
if err := applyOIDCConfig(ctx, loadedConfig); err != nil {
if err := ApplyOIDCConfig(ctx, loadedConfig); err != nil {
return nil, err
}
logConfigInfo(loadedConfig)
LogConfigInfo(loadedConfig)
if err := ensureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
if err := EnsureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
return loadedConfig, nil
}
// applyCommandLineOverrides applies command-line flag overrides to the config
func applyCommandLineOverrides(cfg *nbconfig.Config) {
// ApplyCommandLineOverrides applies command-line flag overrides to the config
func ApplyCommandLineOverrides(cfg *nbconfig.Config) {
if mgmtLetsencryptDomain != "" {
cfg.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
}
@@ -176,9 +176,9 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
}
}
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// ApplyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
@@ -227,8 +227,8 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
return nil
}
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
// ApplyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func ApplyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint == "" {
return nil
@@ -254,16 +254,16 @@ func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcConfig.JwksURI, cfg.HttpConfig.AuthKeysLocation)
cfg.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if err := applyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
if err := ApplyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
return err
}
applyPKCEFlowConfig(ctx, cfg, &oidcConfig)
ApplyPKCEFlowConfig(ctx, cfg, &oidcConfig)
return nil
}
// applyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
// ApplyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func ApplyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
if cfg.DeviceAuthorizationFlow == nil || strings.ToLower(cfg.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE) {
return nil
}
@@ -290,8 +290,8 @@ func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcCo
return nil
}
// applyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
// ApplyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func ApplyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
if cfg.PKCEAuthorizationFlow == nil {
return
}
@@ -304,8 +304,8 @@ func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *
cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
}
// logConfigInfo logs informational messages about the loaded configuration
func logConfigInfo(cfg *nbconfig.Config) {
// LogConfigInfo logs informational messages about the loaded configuration
func LogConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
@@ -314,8 +314,8 @@ func logConfigInfo(cfg *nbconfig.Config) {
}
}
// ensureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func ensureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
// EnsureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func EnsureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}

View File

@@ -30,7 +30,7 @@ func Test_loadMgmtConfig(t *testing.T) {
t.Fatalf("failed to create config: %s", err)
}
cfg, err := loadMgmtConfig(context.Background(), tmpFile)
cfg, err := LoadMgmtConfig(context.Background(), tmpFile)
if err != nil {
t.Fatalf("failed to load management config: %s", err)
}

View File

@@ -67,6 +67,7 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
return fmt.Errorf("init log: %w", err)
}
//nolint
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
@@ -108,11 +109,11 @@ func tokenCreateRun(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("save token: %w", err)
}
fmt.Println("Token created successfully!")
fmt.Printf("Token: %s\n", generated.PlainToken)
fmt.Println()
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.")
fmt.Printf("Token ID: %s\n", generated.ID)
fmt.Println("Token created successfully!") //nolint:forbidigo
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
fmt.Println() //nolint:forbidigo
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
return nil
})
@@ -126,7 +127,7 @@ func tokenListRun(cmd *cobra.Command, _ []string) error {
}
if len(tokens) == 0 {
fmt.Println("No proxy access tokens found.")
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
return nil
}
@@ -174,7 +175,7 @@ func tokenRevokeRun(cmd *cobra.Command, args []string) error {
return fmt.Errorf("revoke token: %w", err)
}
fmt.Printf("Token %s revoked successfully.\n", tokenID)
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
return nil
})
}

View File

@@ -162,3 +162,17 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
}
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}

View File

@@ -13,7 +13,7 @@ import (
type AccessLogEntry struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ProxyID string `gorm:"index"`
ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
Method string `gorm:"index"`
@@ -27,28 +27,28 @@ type AccessLogEntry struct {
}
// FromProto creates an AccessLogEntry from a proto.AccessLog
func (a *AccessLogEntry) FromProto(proxyLog *proto.AccessLog) {
a.ID = proxyLog.GetLogId()
a.ProxyID = proxyLog.GetServiceId()
a.Timestamp = proxyLog.GetTimestamp().AsTime()
a.Method = proxyLog.GetMethod()
a.Host = proxyLog.GetHost()
a.Path = proxyLog.GetPath()
a.Duration = time.Duration(proxyLog.GetDurationMs()) * time.Millisecond
a.StatusCode = int(proxyLog.GetResponseCode())
a.UserId = proxyLog.GetUserId()
a.AuthMethodUsed = proxyLog.GetAuthMechanism()
a.AccountID = proxyLog.GetAccountId()
func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
a.ID = serviceLog.GetLogId()
a.ServiceID = serviceLog.GetServiceId()
a.Timestamp = serviceLog.GetTimestamp().AsTime()
a.Method = serviceLog.GetMethod()
a.Host = serviceLog.GetHost()
a.Path = serviceLog.GetPath()
a.Duration = time.Duration(serviceLog.GetDurationMs()) * time.Millisecond
a.StatusCode = int(serviceLog.GetResponseCode())
a.UserId = serviceLog.GetUserId()
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
a.AccountID = serviceLog.GetAccountId()
if sourceIP := proxyLog.GetSourceIp(); sourceIP != "" {
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
if ip, err := netip.ParseAddr(sourceIP); err == nil {
a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice())
}
}
if !proxyLog.GetAuthSuccess() {
if !serviceLog.GetAuthSuccess() {
a.Reason = "Authentication failed"
} else if proxyLog.GetResponseCode() >= 400 {
} else if serviceLog.GetResponseCode() >= 400 {
a.Reason = "Request failed"
}
}
@@ -88,7 +88,7 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
return &api.ProxyAccessLog{
Id: a.ID,
ProxyId: a.ProxyID,
ServiceId: a.ServiceID,
Timestamp: a.Timestamp,
Method: a.Method,
Host: a.Host,

View File

@@ -21,7 +21,7 @@ type AccessLogFilter struct {
PageSize int
// Filtering parameters
Search *string // General search across host, path, source IP, and user fields
Search *string // General search across log ID, host, path, source IP, and user fields
SourceIP *string // Filter by source IP address
Host *string // Filter by host header
Path *string // Filter by request path (supports LIKE pattern)

View File

@@ -44,11 +44,11 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
if err := m.store.CreateAccessLog(ctx, logEntry); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"proxy_id": logEntry.ProxyID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
"status": logEntry.StatusCode,
"service_id": logEntry.ServiceID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
"status": logEntry.StatusCode,
}).Errorf("failed to save access log: %v", err)
return err
}

View File

@@ -0,0 +1,17 @@
package domain
type Type string
const (
TypeFree Type = "free"
TypeCustom Type = "custom"
)
type Domain struct {
ID string `gorm:"unique;primaryKey;autoIncrement"`
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
AccountID string `gorm:"index"`
TargetCluster string // The proxy cluster this domain should be validated against
Type Type `gorm:"-"`
Validated bool
}

View File

@@ -0,0 +1,12 @@
package domain
import (
"context"
)
type Manager interface {
GetDomains(ctx context.Context, accountID, userID string) ([]*Domain, error)
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
}

View File

@@ -1,10 +1,12 @@
package domain
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -26,11 +28,11 @@ func RegisterEndpoints(router *mux.Router, manager Manager) {
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
}
func domainTypeToApi(t domainType) api.ReverseProxyDomainType {
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
switch t {
case TypeCustom:
case domain.TypeCustom:
return api.ReverseProxyDomainTypeCustom
case TypeFree:
case domain.TypeFree:
return api.ReverseProxyDomainTypeFree
}
// By default return as a "free" domain as that is more restrictive.
@@ -38,7 +40,7 @@ func domainTypeToApi(t domainType) api.ReverseProxyDomainType {
return api.ReverseProxyDomainTypeFree
}
func domainToApi(d *Domain) api.ReverseProxyDomain {
func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
resp := api.ReverseProxyDomain{
Domain: d.Domain,
Id: d.ID,
@@ -58,7 +60,7 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
return
}
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId)
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -85,7 +87,7 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
return
}
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, req.Domain, req.TargetCluster)
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, req.Domain, req.TargetCluster)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -107,7 +109,7 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, domainID); err != nil {
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -128,7 +130,7 @@ func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.R
return
}
go h.manager.ValidateDomain(userAuth.AccountId, domainID)
go h.manager.ValidateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID)
w.WriteHeader(http.StatusAccepted)
}

View File

@@ -1,40 +1,29 @@
package domain
package manager
import (
"context"
"fmt"
"net"
"net/url"
"strings"
"github.com/netbirdio/netbird/management/server/types"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
type domainType string
const (
TypeFree domainType = "free"
TypeCustom domainType = "custom"
)
type Domain struct {
ID string `gorm:"unique;primaryKey;autoIncrement"`
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
AccountID string `gorm:"index"`
TargetCluster string // The proxy cluster this domain should be validated against
Type domainType `gorm:"-"`
Validated bool
}
type store interface {
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*Domain, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
ListCustomDomains(ctx context.Context, accountID string) ([]*Domain, error)
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*Domain, error)
UpdateCustomDomain(ctx context.Context, accountID string, d *Domain) (*Domain, error)
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
@@ -43,28 +32,38 @@ type proxyURLProvider interface {
}
type Manager struct {
store store
validator Validator
proxyURLProvider proxyURLProvider
store store
validator domain.Validator
proxyURLProvider proxyURLProvider
permissionsManager permissions.Manager
}
func NewManager(store store, proxyURLProvider proxyURLProvider) Manager {
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
return Manager{
store: store,
proxyURLProvider: proxyURLProvider,
validator: Validator{
resolver: net.DefaultResolver,
validator: domain.Validator{
Resolver: net.DefaultResolver,
},
permissionsManager: permissionsManager,
}
}
func (m Manager) GetDomains(ctx context.Context, accountID string) ([]*Domain, error) {
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
domains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("list custom domains: %w", err)
}
var ret []*Domain
var ret []*domain.Domain
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
@@ -75,30 +74,37 @@ func (m Manager) GetDomains(ctx context.Context, accountID string) ([]*Domain, e
}).Debug("getting domains with proxy allow list")
for _, cluster := range allowList {
ret = append(ret, &Domain{
ret = append(ret, &domain.Domain{
Domain: cluster,
AccountID: accountID,
Type: TypeFree,
Type: domain.TypeFree,
Validated: true,
})
}
// Add custom domains.
for _, domain := range domains {
ret = append(ret, &Domain{
ID: domain.ID,
Domain: domain.Domain,
for _, d := range domains {
ret = append(ret, &domain.Domain{
ID: d.ID,
Domain: d.Domain,
AccountID: accountID,
TargetCluster: domain.TargetCluster,
Type: TypeCustom,
Validated: domain.Validated,
TargetCluster: d.TargetCluster,
Type: domain.TypeCustom,
Validated: d.Validated,
})
}
return ret, nil
}
func (m Manager) CreateDomain(ctx context.Context, accountID, domainName, targetCluster string) (*Domain, error) {
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList()
@@ -126,7 +132,15 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, domainName, target
return d, nil
}
func (m Manager) DeleteDomain(ctx context.Context, accountID, domainID string) error {
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
// TODO: check for "no records" type error. Because that is a success condition.
return fmt.Errorf("delete domain from store: %w", err)
@@ -134,7 +148,22 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, domainID string) e
return nil
}
func (m Manager) ValidateDomain(accountID, domainID string) {
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
return
}
if !ok {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
}
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
@@ -193,57 +222,19 @@ func (m Manager) ValidateDomain(accountID, domainID string) {
}
// proxyURLAllowList retrieves a list of currently connected proxies and
// their URLs (as reported by the proxy servers). It performs some clean
// up on those URLs to attempt to retrieve domain names as we would
// expect to see them in a validation check.
// their URLs
func (m Manager) proxyURLAllowList() []string {
var reverseProxyAddresses []string
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
}
var allowedProxyURLs []string
for _, addr := range reverseProxyAddresses {
if addr == "" {
continue
}
host := extractHostFromAddress(addr)
if host != "" {
allowedProxyURLs = append(allowedProxyURLs, host)
}
}
return allowedProxyURLs
}
// extractHostFromAddress extracts the hostname from an address string.
// It handles both URL format (https://host:port) and plain hostname (host or host:port).
func extractHostFromAddress(addr string) string {
// If it looks like a URL with a scheme, parse it
if strings.Contains(addr, "://") {
proxyUrl, err := url.Parse(addr)
if err != nil {
log.WithError(err).Debugf("failed to parse proxy URL %s", addr)
return ""
}
host, _, err := net.SplitHostPort(proxyUrl.Host)
if err != nil {
return proxyUrl.Host
}
return host
}
// Otherwise treat as hostname or host:port
host, _, err := net.SplitHostPort(addr)
if err != nil {
// No port, use as-is
return addr
}
return host
return reverseProxyAddresses
}
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by looking up the CNAME target.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, domain string) (string, error) {
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList := m.proxyURLAllowList()
if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available")
@@ -253,10 +244,36 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, domain string) (st
return cluster, nil
}
cluster, valid := m.validator.ValidateWithCluster(ctx, domain, allowList)
customDomains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
return "", fmt.Errorf("list custom domains: %w", err)
}
targetCluster, valid := extractClusterFromCustomDomains(domain, customDomains)
if valid {
return cluster, nil
return targetCluster, nil
}
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
for _, customDomain := range customDomains {
if strings.HasSuffix(domain, "."+customDomain.Domain) {
return customDomain.TargetCluster, true
}
}
return "", false
}
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
// Free domains have the format: <name>.<nonce>.<cluster> (e.g., myapp.abc123.eu.proxy.netbird.io)
// It matches the domain suffix against available clusters and returns the matching cluster.
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
for _, cluster := range availableClusters {
if strings.HasSuffix(domain, "."+cluster) {
return cluster, true
}
}
return "", false
}

View File

@@ -13,15 +13,15 @@ type resolver interface {
}
type Validator struct {
resolver resolver
Resolver resolver
}
// NewValidator initializes a validator with a specific DNS resolver.
// If a Validator is used without specifying a resolver, then it will
// NewValidator initializes a validator with a specific DNS Resolver.
// If a Validator is used without specifying a Resolver, then it will
// use the net.DefaultResolver.
func NewValidator(resolver resolver) *Validator {
return &Validator{
resolver: resolver,
Resolver: resolver,
}
}
@@ -39,8 +39,8 @@ func (v *Validator) IsValid(ctx context.Context, domain string, accept []string)
// ValidateWithCluster validates a custom domain and returns the matched cluster address.
// Returns the cluster address and true if valid, or empty string and false if invalid.
func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) {
if v.resolver == nil {
v.resolver = net.DefaultResolver
if v.Resolver == nil {
v.Resolver = net.DefaultResolver
}
lookupDomain := "validation." + domain
@@ -50,7 +50,7 @@ func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, acce
"acceptList": accept,
}).Debug("looking up CNAME for domain validation")
cname, err := v.resolver.LookupCNAME(ctx, lookupDomain)
cname, err := v.Resolver.LookupCNAME(ctx, lookupDomain)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
@@ -86,15 +86,3 @@ func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, acce
}).Warn("domain CNAME does not match any accepted cluster")
return "", false
}
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
// Free domains have the format: <name>.<nonce>.<cluster> (e.g., myapp.abc123.eu.proxy.netbird.io)
// It matches the domain suffix against available clusters and returns the matching cluster.
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
for _, cluster := range availableClusters {
if strings.HasSuffix(domain, "."+cluster) {
return cluster, true
}
}
return "", false
}

View File

@@ -1,21 +1,23 @@
package reverseproxy
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import (
"context"
)
type Manager interface {
GetAllReverseProxies(ctx context.Context, accountID, userID string) ([]*ReverseProxy, error)
GetReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) (*ReverseProxy, error)
CreateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *ReverseProxy) (*ReverseProxy, error)
UpdateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *ReverseProxy) (*ReverseProxy, error)
DeleteReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error
SetStatus(ctx context.Context, accountID, reverseProxyID string, status ProxyStatus) error
ReloadAllReverseProxiesForAccount(ctx context.Context, accountID string) error
ReloadReverseProxy(ctx context.Context, accountID, reverseProxyID string) error
GetGlobalReverseProxies(ctx context.Context) ([]*ReverseProxy, error)
GetProxyByID(ctx context.Context, accountID, reverseProxyID string) (*ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, accountID string) ([]*ReverseProxy, error)
GetProxyIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error)
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
}

View File

@@ -0,0 +1,225 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
// Package reverseproxy is a generated GoMock package.
package reverseproxy
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CreateService mocks base method.
func (m *MockManager) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateService", ctx, accountID, userID, service)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateService indicates an expected call of CreateService.
func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteService", ctx, accountID, userID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteService indicates an expected call of DeleteService.
func (mr *MockManagerMockRecorder) DeleteService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockManager)(nil).DeleteService), ctx, accountID, userID, serviceID)
}
// GetAccountServices mocks base method.
func (m *MockManager) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountServices", ctx, accountID)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAccountServices indicates an expected call of GetAccountServices.
func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
}
// GetAllServices mocks base method.
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllServices", ctx, accountID, userID)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllServices indicates an expected call of GetAllServices.
func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGlobalServices", ctx)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGlobalServices indicates an expected call of GetGlobalServices.
func (mr *MockManagerMockRecorder) GetGlobalServices(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalServices", reflect.TypeOf((*MockManager)(nil).GetGlobalServices), ctx)
}
// GetService mocks base method.
func (m *MockManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetService", ctx, accountID, userID, serviceID)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetService indicates an expected call of GetService.
func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
}
// GetServiceByID mocks base method.
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByID", ctx, accountID, serviceID)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByID indicates an expected call of GetServiceByID.
func (mr *MockManagerMockRecorder) GetServiceByID(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByID", reflect.TypeOf((*MockManager)(nil).GetServiceByID), ctx, accountID, serviceID)
}
// GetServiceIDByTargetID mocks base method.
func (m *MockManager) GetServiceIDByTargetID(ctx context.Context, accountID, resourceID string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceIDByTargetID", ctx, accountID, resourceID)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceIDByTargetID indicates an expected call of GetServiceIDByTargetID.
func (mr *MockManagerMockRecorder) GetServiceIDByTargetID(ctx, accountID, resourceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceIDByTargetID", reflect.TypeOf((*MockManager)(nil).GetServiceIDByTargetID), ctx, accountID, resourceID)
}
// ReloadAllServicesForAccount mocks base method.
func (m *MockManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReloadAllServicesForAccount", ctx, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// ReloadAllServicesForAccount indicates an expected call of ReloadAllServicesForAccount.
func (mr *MockManagerMockRecorder) ReloadAllServicesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadAllServicesForAccount", reflect.TypeOf((*MockManager)(nil).ReloadAllServicesForAccount), ctx, accountID)
}
// ReloadService mocks base method.
func (m *MockManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReloadService", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// ReloadService indicates an expected call of ReloadService.
func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
}
// SetCertificateIssuedAt mocks base method.
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetCertificateIssuedAt", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// SetCertificateIssuedAt indicates an expected call of SetCertificateIssuedAt.
func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCertificateIssuedAt", reflect.TypeOf((*MockManager)(nil).SetCertificateIssuedAt), ctx, accountID, serviceID)
}
// SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error)
return ret0
}
// SetStatus indicates an expected call of SetStatus.
func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
}
// UpdateService mocks base method.
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateService", ctx, accountID, userID, service)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateService indicates an expected call of UpdateService.
func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -20,148 +20,148 @@ type handler struct {
manager reverseproxy.Manager
}
// RegisterEndpoints registers all reverse proxy HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domain.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
domain.RegisterEndpoints(domainRouter, domainManager)
domainmanager.RegisterEndpoints(domainRouter, domainManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies", h.getAllReverseProxies).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies", h.createReverseProxy).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/{proxyId}", h.getReverseProxy).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/{proxyId}", h.updateReverseProxy).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/{proxyId}", h.deleteReverseProxy).Methods("DELETE", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllReverseProxies(w http.ResponseWriter, r *http.Request) {
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allReverseProxies, err := h.manager.GetAllReverseProxies(r.Context(), userAuth.AccountId, userAuth.UserId)
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiReverseProxies := make([]*api.ReverseProxy, 0, len(allReverseProxies))
for _, reverseProxy := range allReverseProxies {
apiReverseProxies = append(apiReverseProxies, reverseProxy.ToAPIResponse())
apiServices := make([]*api.Service, 0, len(allServices))
for _, service := range allServices {
apiServices = append(apiServices, service.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiReverseProxies)
util.WriteJSONObject(r.Context(), w, apiServices)
}
func (h *handler) createReverseProxy(w http.ResponseWriter, r *http.Request) {
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.ReverseProxyRequest
var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
reverseProxy := new(reverseproxy.ReverseProxy)
reverseProxy.FromAPIRequest(&req, userAuth.AccountId)
service := new(reverseproxy.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
if err = reverseProxy.Validate(); err != nil {
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdReverseProxy, err := h.manager.CreateReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxy)
createdService, err := h.manager.CreateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdReverseProxy.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
}
func (h *handler) getReverseProxy(w http.ResponseWriter, r *http.Request) {
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
reverseProxyID := mux.Vars(r)["proxyId"]
if reverseProxyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "reverse proxy ID is required"), w)
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
return
}
reverseProxy, err := h.manager.GetReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxyID)
service, err := h.manager.GetService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, reverseProxy.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
}
func (h *handler) updateReverseProxy(w http.ResponseWriter, r *http.Request) {
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
reverseProxyID := mux.Vars(r)["proxyId"]
if reverseProxyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "reverse proxy ID is required"), w)
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
return
}
var req api.ReverseProxyRequest
var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
reverseProxy := new(reverseproxy.ReverseProxy)
reverseProxy.ID = reverseProxyID
reverseProxy.FromAPIRequest(&req, userAuth.AccountId)
service := new(reverseproxy.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)
if err = reverseProxy.Validate(); err != nil {
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedReverseProxy, err := h.manager.UpdateReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxy)
updatedService, err := h.manager.UpdateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedReverseProxy.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
}
func (h *handler) deleteReverseProxy(w http.ResponseWriter, r *http.Request) {
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
reverseProxyID := mux.Vars(r)["proxyId"]
if reverseProxyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "reverse proxy ID is required"), w)
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
return
}
if err := h.manager.DeleteReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxyID); err != nil {
if err := h.manager.DeleteService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -23,7 +23,7 @@ const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, domain string) (string, error)
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
}
type managerImpl struct {
@@ -31,23 +31,21 @@ type managerImpl struct {
accountManager account.Manager
permissionsManager permissions.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
tokenStore *nbgrpc.OneTimeTokenStore
clusterDeriver ClusterDeriver
}
// NewManager creates a new reverse proxy manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, tokenStore *nbgrpc.OneTimeTokenStore, clusterDeriver ClusterDeriver) reverseproxy.Manager {
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyGRPCServer: proxyGRPCServer,
tokenStore: tokenStore,
clusterDeriver: clusterDeriver,
}
}
func (m *managerImpl) GetAllReverseProxies(ctx context.Context, accountID, userID string) ([]*reverseproxy.ReverseProxy, error) {
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -56,28 +54,28 @@ func (m *managerImpl) GetAllReverseProxies(ctx context.Context, accountID, userI
return nil, status.NewPermissionDeniedError()
}
proxies, err := m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get reverse proxies: %w", err)
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, proxy := range proxies {
err = m.replaceHostByLookup(ctx, accountID, proxy)
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return proxies, nil
return services, nil
}
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, proxy *reverseproxy.ReverseProxy) error {
for _, target := range proxy.Targets {
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
for _, target := range service.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for reverse proxy %s: %v", target.TargetId, proxy.ID, err)
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
@@ -85,7 +83,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string,
case reverseproxy.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for reverse proxy %s: %v", target.TargetId, proxy.ID, err)
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
@@ -93,7 +91,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string,
case reverseproxy.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for reverse proxy %s: %v", target.TargetId, proxy.ID, err)
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
@@ -107,7 +105,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string,
return nil
}
func (m *managerImpl) GetReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) {
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -116,19 +114,19 @@ func (m *managerImpl) GetReverseProxy(ctx context.Context, accountID, userID, re
return nil, status.NewPermissionDeniedError()
}
proxy, err := m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, reverseProxyID)
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get reverse proxy: %w", err)
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, proxy)
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return proxy, nil
return service, nil
}
func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -139,16 +137,17 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID
var proxyCluster string
if m.clusterDeriver != nil {
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, reverseProxy.Domain)
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxies", reverseProxy.Domain)
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
return nil, status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
}
reverseProxy.AccountID = accountID
reverseProxy.ProxyCluster = proxyCluster
reverseProxy.InitNewRecord()
err = reverseProxy.Auth.HashSecrets()
service.AccountID = accountID
service.ProxyCluster = proxyCluster
service.InitNewRecord()
err = service.Auth.HashSecrets()
if err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
@@ -158,27 +157,27 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID
if err != nil {
return nil, fmt.Errorf("generate session keys: %w", err)
}
reverseProxy.SessionPrivateKey = keyPair.PrivateKey
reverseProxy.SessionPublicKey = keyPair.PublicKey
service.SessionPrivateKey = keyPair.PrivateKey
service.SessionPublicKey = keyPair.PublicKey
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Check for duplicate domain
existingReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)
existingService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing reverse proxy: %w", err)
return fmt.Errorf("failed to check existing service: %w", err)
}
}
if existingReverseProxy != nil {
return status.Errorf(status.AlreadyExists, "reverse proxy with domain %s already exists", reverseProxy.Domain)
if existingService != nil {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
}
if err = validateTargetReferences(ctx, transaction, accountID, reverseProxy.Targets); err != nil {
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err = transaction.CreateReverseProxy(ctx, reverseProxy); err != nil {
return fmt.Errorf("failed to create reverse proxy: %w", err)
if err = transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
@@ -187,26 +186,21 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID
return nil, err
}
token, err := m.tokenStore.GenerateToken(accountID, reverseProxy.ID, 5*time.Minute)
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to generate authentication token: %w", err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyCreated, reverseProxy.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, reverseProxy)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", reverseProxy.ID, err)
}
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return reverseProxy, nil
return service, nil
}
func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -217,67 +211,67 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID
var oldCluster string
var domainChanged bool
var reverseProxyEnabledChanged bool
var serviceEnabledChanged bool
err = reverseProxy.Auth.HashSecrets()
err = service.Auth.HashSecrets()
if err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingReverseProxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxy.ID)
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
oldCluster = existingReverseProxy.ProxyCluster
oldCluster = existingService.ProxyCluster
if existingReverseProxy.Domain != reverseProxy.Domain {
if existingService.Domain != service.Domain {
domainChanged = true
conflictReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)
conflictService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("check existing reverse proxy: %w", err)
return fmt.Errorf("check existing service: %w", err)
}
}
if conflictReverseProxy != nil && conflictReverseProxy.ID != reverseProxy.ID {
return status.Errorf(status.AlreadyExists, "reverse proxy with domain %s already exists", reverseProxy.Domain)
if conflictService != nil && conflictService.ID != service.ID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, reverseProxy.Domain)
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", reverseProxy.Domain)
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
}
reverseProxy.ProxyCluster = newCluster
service.ProxyCluster = newCluster
}
} else {
reverseProxy.ProxyCluster = existingReverseProxy.ProxyCluster
service.ProxyCluster = existingService.ProxyCluster
}
if reverseProxy.Auth.PasswordAuth != nil && reverseProxy.Auth.PasswordAuth.Enabled &&
existingReverseProxy.Auth.PasswordAuth != nil && existingReverseProxy.Auth.PasswordAuth.Enabled &&
reverseProxy.Auth.PasswordAuth.Password == "" {
reverseProxy.Auth.PasswordAuth = existingReverseProxy.Auth.PasswordAuth
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if reverseProxy.Auth.PinAuth != nil && reverseProxy.Auth.PinAuth.Enabled &&
existingReverseProxy.Auth.PinAuth != nil && existingReverseProxy.Auth.PinAuth.Enabled &&
reverseProxy.Auth.PinAuth.Pin == "" {
reverseProxy.Auth.PinAuth = existingReverseProxy.Auth.PinAuth
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
}
reverseProxy.Meta = existingReverseProxy.Meta
reverseProxy.SessionPrivateKey = existingReverseProxy.SessionPrivateKey
reverseProxy.SessionPublicKey = existingReverseProxy.SessionPublicKey
reverseProxyEnabledChanged = existingReverseProxy.Enabled != reverseProxy.Enabled
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
serviceEnabledChanged = existingService.Enabled != service.Enabled
if err = validateTargetReferences(ctx, transaction, accountID, reverseProxy.Targets); err != nil {
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err = transaction.UpdateReverseProxy(ctx, reverseProxy); err != nil {
return fmt.Errorf("update reverse proxy: %w", err)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
@@ -286,33 +280,28 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyUpdated, reverseProxy.EventMeta())
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, reverseProxy)
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", reverseProxy.ID, err)
}
token, err := m.tokenStore.GenerateToken(accountID, reverseProxy.ID, 5*time.Minute)
if err != nil {
return nil, fmt.Errorf("failed to generate authentication token: %w", err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
switch {
case domainChanged && oldCluster != reverseProxy.ProxyCluster:
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), oldCluster)
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
case !reverseProxy.Enabled && reverseProxyEnabledChanged:
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
case reverseProxy.Enabled && reverseProxyEnabledChanged:
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
case domainChanged && oldCluster != service.ProxyCluster:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), oldCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
case !service.Enabled && serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
case service.Enabled && serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
default:
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return reverseProxy, nil
return service, nil
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
@@ -338,7 +327,7 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil
}
func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) error {
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -347,16 +336,16 @@ func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID,
return status.NewPermissionDeniedError()
}
var reverseProxy *reverseproxy.ReverseProxy
var service *reverseproxy.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
reverseProxy, err = transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxyID)
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if err = transaction.DeleteReverseProxy(ctx, accountID, reverseProxyID); err != nil {
return fmt.Errorf("failed to delete reverse proxy: %w", err)
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
return nil
@@ -365,9 +354,9 @@ func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID,
return err
}
m.accountManager.StoreEvent(ctx, userID, reverseProxyID, accountID, activity.ReverseProxyDeleted, reverseProxy.EventMeta())
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
@@ -376,71 +365,71 @@ func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID,
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
proxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxyID)
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get reverse proxy: %w", err)
return fmt.Errorf("failed to get service: %w", err)
}
proxy.Meta.CertificateIssuedAt = time.Now()
service.Meta.CertificateIssuedAt = time.Now()
if err = transaction.UpdateReverseProxy(ctx, proxy); err != nil {
return fmt.Errorf("failed to update reverse proxy certificate timestamp: %w", err)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
}
return nil
})
}
// SetStatus updates the status of the reverse proxy (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
proxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxyID)
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get reverse proxy: %w", err)
return fmt.Errorf("failed to get service: %w", err)
}
proxy.Meta.Status = string(status)
service.Meta.Status = string(status)
if err = transaction.UpdateReverseProxy(ctx, proxy); err != nil {
return fmt.Errorf("failed to update reverse proxy status: %w", err)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service status: %w", err)
}
return nil
})
}
func (m *managerImpl) ReloadReverseProxy(ctx context.Context, accountID, reverseProxyID string) error {
proxy, err := m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, reverseProxyID)
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get reverse proxy: %w", err)
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, proxy)
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(proxy.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), proxy.ProxyCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) ReloadAllReverseProxiesForAccount(ctx context.Context, accountID string) error {
proxies, err := m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get reverse proxies: %w", err)
return fmt.Errorf("failed to get services: %w", err)
}
for _, proxy := range proxies {
err = m.replaceHostByLookup(ctx, accountID, proxy)
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(proxy.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), proxy.ProxyCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
@@ -448,61 +437,64 @@ func (m *managerImpl) ReloadAllReverseProxiesForAccount(ctx context.Context, acc
return nil
}
func (m *managerImpl) GetGlobalReverseProxies(ctx context.Context) ([]*reverseproxy.ReverseProxy, error) {
proxies, err := m.store.GetReverseProxies(ctx, store.LockingStrengthNone)
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get reverse proxies: %w", err)
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, proxy := range proxies {
err = m.replaceHostByLookup(ctx, proxy.AccountID, proxy)
for _, service := range services {
err = m.replaceHostByLookup(ctx, service.AccountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return proxies, nil
return services, nil
}
func (m *managerImpl) GetProxyByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) {
proxy, err := m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, reverseProxyID)
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get reverse proxy: %w", err)
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, proxy)
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return proxy, nil
return service, nil
}
func (m *managerImpl) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
proxies, err := m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get reverse proxies: %w", err)
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, proxy := range proxies {
err = m.replaceHostByLookup(ctx, accountID, proxy)
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return proxies, nil
return services, nil
}
func (m *managerImpl) GetProxyIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetReverseProxyTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
return "", fmt.Errorf("failed to get reverse proxy target by resource ID: %w", err)
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return "", nil
}
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
}
if target == nil {
return "", nil
}
return target.ReverseProxyID, nil
return target.ServiceID, nil
}

View File

@@ -43,16 +43,16 @@ const (
)
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ReverseProxyID string `gorm:"index:idx_reverse_proxy_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
}
type PasswordAuthConfig struct {
@@ -112,34 +112,34 @@ type OIDCValidationConfig struct {
MaxTokenAgeSeconds int64
}
type ReverseProxyMeta struct {
type ServiceMeta struct {
CreatedAt time.Time
CertificateIssuedAt time.Time
Status string
}
type ReverseProxy struct {
type Service struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ReverseProxyID;constraint:OnDelete:CASCADE"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Auth AuthConfig `gorm:"serializer:json"`
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
}
func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *ReverseProxy {
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
for _, target := range targets {
target.AccountID = accountID
}
rp := &ReverseProxy{
s := &Service{
AccountID: accountID,
Name: name,
Domain: domain,
@@ -147,92 +147,92 @@ func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []*Ta
Targets: targets,
Enabled: enabled,
}
rp.InitNewRecord()
return rp
s.InitNewRecord()
return s
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// ReverseProxy record. This overwrites any existing ID and Meta fields and should
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (r *ReverseProxy) InitNewRecord() {
r.ID = xid.New().String()
r.Meta = ReverseProxyMeta{
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = ServiceMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
}
}
func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy {
r.Auth.ClearSecrets()
func (s *Service) ToAPIResponse() *api.Service {
s.Auth.ClearSecrets()
authConfig := api.ReverseProxyAuthConfig{}
authConfig := api.ServiceAuthConfig{}
if r.Auth.PasswordAuth != nil {
if s.Auth.PasswordAuth != nil {
authConfig.PasswordAuth = &api.PasswordAuthConfig{
Enabled: r.Auth.PasswordAuth.Enabled,
Password: r.Auth.PasswordAuth.Password,
Enabled: s.Auth.PasswordAuth.Enabled,
Password: s.Auth.PasswordAuth.Password,
}
}
if r.Auth.PinAuth != nil {
if s.Auth.PinAuth != nil {
authConfig.PinAuth = &api.PINAuthConfig{
Enabled: r.Auth.PinAuth.Enabled,
Pin: r.Auth.PinAuth.Pin,
Enabled: s.Auth.PinAuth.Enabled,
Pin: s.Auth.PinAuth.Pin,
}
}
if r.Auth.BearerAuth != nil {
if s.Auth.BearerAuth != nil {
authConfig.BearerAuth = &api.BearerAuthConfig{
Enabled: r.Auth.BearerAuth.Enabled,
DistributionGroups: &r.Auth.BearerAuth.DistributionGroups,
Enabled: s.Auth.BearerAuth.Enabled,
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
}
}
// Convert internal targets to API targets
apiTargets := make([]api.ReverseProxyTarget, 0, len(r.Targets))
for _, target := range r.Targets {
apiTargets = append(apiTargets, api.ReverseProxyTarget{
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
apiTargets = append(apiTargets, api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
Protocol: api.ReverseProxyTargetProtocol(target.Protocol),
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ReverseProxyTargetTargetType(target.TargetType),
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
})
}
meta := api.ReverseProxyMeta{
CreatedAt: r.Meta.CreatedAt,
Status: api.ReverseProxyMetaStatus(r.Meta.Status),
meta := api.ServiceMeta{
CreatedAt: s.Meta.CreatedAt,
Status: api.ServiceMetaStatus(s.Meta.Status),
}
if !r.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &r.Meta.CertificateIssuedAt
if !s.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
}
resp := &api.ReverseProxy{
Id: r.ID,
Name: r.Name,
Domain: r.Domain,
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: r.Enabled,
PassHostHeader: &r.PassHostHeader,
RewriteRedirects: &r.RewriteRedirects,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
}
if r.ProxyCluster != "" {
resp.ProxyCluster = &r.ProxyCluster
if s.ProxyCluster != "" {
resp.ProxyCluster = &s.ProxyCluster
}
return resp
}
func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(r.Targets))
for _, target := range r.Targets {
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
continue
}
@@ -260,32 +260,32 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oid
}
auth := &proto.Authentication{
SessionKey: r.SessionPublicKey,
SessionKey: s.SessionPublicKey,
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
}
if r.Auth.PasswordAuth != nil && r.Auth.PasswordAuth.Enabled {
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
auth.Password = true
}
if r.Auth.PinAuth != nil && r.Auth.PinAuth.Enabled {
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
auth.Pin = true
}
if r.Auth.BearerAuth != nil && r.Auth.BearerAuth.Enabled {
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
auth.Oidc = true
}
return &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: r.ID,
Domain: r.Domain,
Id: s.ID,
Domain: s.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: r.AccountID,
PassHostHeader: r.PassHostHeader,
RewriteRedirects: r.RewriteRedirects,
AccountId: s.AccountID,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
}
}
@@ -309,10 +309,10 @@ func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID string) {
r.Name = req.Name
r.Domain = req.Domain
r.AccountID = accountID
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets {
@@ -330,27 +330,27 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st
}
targets = append(targets, target)
}
r.Targets = targets
s.Targets = targets
r.Enabled = req.Enabled
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
r.PassHostHeader = *req.PassHostHeader
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
r.RewriteRedirects = *req.RewriteRedirects
s.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth.PasswordAuth != nil {
r.Auth.PasswordAuth = &PasswordAuthConfig{
s.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
Password: req.Auth.PasswordAuth.Password,
}
}
if req.Auth.PinAuth != nil {
r.Auth.PinAuth = &PINAuthConfig{
s.Auth.PinAuth = &PINAuthConfig{
Enabled: req.Auth.PinAuth.Enabled,
Pin: req.Auth.PinAuth.Pin,
}
@@ -363,27 +363,27 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st
if req.Auth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
}
r.Auth.BearerAuth = bearerAuth
s.Auth.BearerAuth = bearerAuth
}
}
func (r *ReverseProxy) Validate() error {
if r.Name == "" {
return errors.New("reverse proxy name is required")
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
}
if len(r.Name) > 255 {
return errors.New("reverse proxy name exceeds maximum length of 255 characters")
if len(s.Name) > 255 {
return errors.New("service name exceeds maximum length of 255 characters")
}
if r.Domain == "" {
return errors.New("reverse proxy domain is required")
if s.Domain == "" {
return errors.New("service domain is required")
}
if len(r.Targets) == 0 {
if len(s.Targets) == 0 {
return errors.New("at least one target is required")
}
for i, target := range r.Targets {
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored
@@ -402,42 +402,42 @@ func (r *ReverseProxy) Validate() error {
return nil
}
func (r *ReverseProxy) EventMeta() map[string]any {
return map[string]any{"name": r.Name, "domain": r.Domain, "proxy_cluster": r.ProxyCluster}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
}
func (r *ReverseProxy) Copy() *ReverseProxy {
targets := make([]*Target, len(r.Targets))
for i, target := range r.Targets {
func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
targets[i] = &targetCopy
}
return &ReverseProxy{
ID: r.ID,
AccountID: r.AccountID,
Name: r.Name,
Domain: r.Domain,
ProxyCluster: r.ProxyCluster,
return &Service{
ID: s.ID,
AccountID: s.AccountID,
Name: s.Name,
Domain: s.Domain,
ProxyCluster: s.ProxyCluster,
Targets: targets,
Enabled: r.Enabled,
PassHostHeader: r.PassHostHeader,
RewriteRedirects: r.RewriteRedirects,
Auth: r.Auth,
Meta: r.Meta,
SessionPrivateKey: r.SessionPrivateKey,
SessionPublicKey: r.SessionPublicKey,
Enabled: s.Enabled,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: s.Auth,
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
}
}
func (r *ReverseProxy) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if r.SessionPrivateKey != "" {
if s.SessionPrivateKey != "" {
var err error
r.SessionPrivateKey, err = enc.Encrypt(r.SessionPrivateKey)
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
if err != nil {
return err
}
@@ -446,14 +446,14 @@ func (r *ReverseProxy) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
return nil
}
func (r *ReverseProxy) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if r.SessionPrivateKey != "" {
if s.SessionPrivateKey != "" {
var err error
r.SessionPrivateKey, err = enc.Decrypt(r.SessionPrivateKey)
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
if err != nil {
return err
}

View File

@@ -13,8 +13,8 @@ import (
"github.com/netbirdio/netbird/shared/management/proto"
)
func validProxy() *ReverseProxy {
return &ReverseProxy{
func validProxy() *Service {
return &Service{
Name: "test",
Domain: "example.com",
Targets: []*Target{
@@ -170,7 +170,7 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &ReverseProxy{
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
@@ -193,7 +193,7 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
}
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
rp := &ReverseProxy{
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",

View File

@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer())
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -103,7 +103,7 @@ func (s *BaseServer) AccountManager() account.Manager {
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetReverseProxyManager(s.ReverseProxyManager())
accountManager.SetServiceManager(s.ReverseProxyManager())
})
return accountManager
@@ -192,13 +192,13 @@ func (s *BaseServer) RecordsManager() records.Manager {
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
return Create(s, func() reverseproxy.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ProxyTokenStore(), s.ReverseProxyDomainManager())
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
})
}
func (s *BaseServer) ReverseProxyDomainManager() *domain.Manager {
return Create(s, func() *domain.Manager {
m := domain.NewManager(s.Store(), s.ReverseProxyGRPCServer())
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
return &m
})
}

View File

@@ -140,6 +140,14 @@ func (s *BaseServer) Start(ctx context.Context) error {
go metricsWorker.Run(srvCtx)
}
// Run afterInit hooks before starting any servers
// This allows registering additional gRPC services (e.g., Signal) before Serve() is called
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
var compatListener net.Listener
if s.mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
@@ -180,12 +188,6 @@ func (s *BaseServer) Start(ctx context.Context) error {
}
}
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
@@ -217,6 +219,7 @@ func (s *BaseServer) Stop() error {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
s.ReverseProxyGRPCServer().Close()
if s.proxyAuthClose != nil {
s.proxyAuthClose()
s.proxyAuthClose = nil
@@ -261,7 +264,23 @@ func (s *BaseServer) SetContainer(key string, container any) {
log.Tracef("container with key %s set successfully", key)
}
// SetHandlerFunc allows overriding the default HTTP handler function.
// This is useful for multiplexing additional services on the same port.
func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
s.container["customHandler"] = handler
log.Tracef("custom handler set successfully")
}
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
// Check if a custom handler was set (for multiplexing additional services)
if customHandler, ok := s.GetContainer("customHandler"); ok {
if handler, ok := customHandler.(http.Handler); ok {
log.Tracef("using custom handler")
return handler
}
}
// Use default handler
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {

View File

@@ -13,7 +13,7 @@ import (
// OneTimeTokenStore manages short-lived, single-use authentication tokens
// for proxy-to-management RPC authentication. Tokens are generated when
// a reverse proxy is created and must be used exactly once by the proxy
// a service is created and must be used exactly once by the proxy
// to authenticate a subsequent RPC call.
type OneTimeTokenStore struct {
tokens map[string]*tokenMetadata
@@ -24,10 +24,10 @@ type OneTimeTokenStore struct {
// tokenMetadata stores information about a one-time token
type tokenMetadata struct {
ReverseProxyID string
AccountID string
ExpiresAt time.Time
CreatedAt time.Time
ServiceID string
AccountID string
ExpiresAt time.Time
CreatedAt time.Time
}
// NewOneTimeTokenStore creates a new token store with automatic cleanup
@@ -48,10 +48,10 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
// GenerateToken creates a new cryptographically secure one-time token
// with the specified TTL. The token is associated with a specific
// accountID and reverseProxyID for validation purposes.
// accountID and serviceID for validation purposes.
//
// Returns the generated token string or an error if random generation fails.
func (s *OneTimeTokenStore) GenerateToken(accountID, reverseProxyID string, ttl time.Duration) (string, error) {
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
// Generate 32 bytes (256 bits) of cryptographically secure random data
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
@@ -65,20 +65,20 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, reverseProxyID string, ttl
defer s.mu.Unlock()
s.tokens[token] = &tokenMetadata{
ReverseProxyID: reverseProxyID,
AccountID: accountID,
ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(),
ServiceID: serviceID,
AccountID: accountID,
ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(),
}
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
reverseProxyID, accountID, ttl)
serviceID, accountID, ttl)
return token, nil
}
// ValidateAndConsume verifies the token against the provided accountID and
// reverseProxyID, checks expiration, and then deletes it to enforce single-use.
// serviceID, checks expiration, and then deletes it to enforce single-use.
//
// This method uses constant-time comparison to prevent timing attacks.
//
@@ -87,14 +87,14 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, reverseProxyID string, ttl
// - Token has expired
// - Account ID doesn't match
// - Reverse proxy ID doesn't match
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, reverseProxyID string) error {
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
s.mu.Lock()
defer s.mu.Unlock()
metadata, exists := s.tokens[token]
if !exists {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
reverseProxyID, accountID)
serviceID, accountID)
return fmt.Errorf("invalid token")
}
@@ -102,7 +102,7 @@ func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, reverseProxyID
if time.Now().After(metadata.ExpiresAt) {
delete(s.tokens, token)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
reverseProxyID, accountID)
serviceID, accountID)
return fmt.Errorf("token expired")
}
@@ -113,18 +113,18 @@ func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, reverseProxyID
return fmt.Errorf("account ID mismatch")
}
// Validate reverse proxy ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ReverseProxyID), []byte(reverseProxyID)) != 1 {
log.Warnf("Token validation failed: reverse proxy ID mismatch (expected: %s, got: %s)",
metadata.ReverseProxyID, reverseProxyID)
return fmt.Errorf("reverse proxy ID mismatch")
// Validate service ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
metadata.ServiceID, serviceID)
return fmt.Errorf("service ID mismatch")
}
// Delete token immediately to enforce single-use
delete(s.tokens, token)
log.Infof("Token validated and consumed for proxy %s in account %s",
reverseProxyID, accountID)
serviceID, accountID)
return nil
}

View File

@@ -3,18 +3,19 @@ package grpc
import (
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/netbirdio/netbird/shared/management/domain"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
@@ -81,7 +82,17 @@ type ProxyServiceServer struct {
oidcConfig ProxyOIDCConfig
// TODO: use database to store these instead?
pkceVerifiers sync.Map
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state.
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
pkceVerifiers sync.Map
pkceCleanupCancel context.CancelFunc
}
const pkceVerifierTTL = 10 * time.Minute
type pkceEntry struct {
verifier string
createdAt time.Time
}
// proxyConnection represents a connected proxy
@@ -92,19 +103,47 @@ type proxyConnection struct {
sendChan chan *proto.ProxyMapping
ctx context.Context
cancel context.CancelFunc
mu sync.RWMutex
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer {
return &ProxyServiceServer{
updatesChan: make(chan *proto.ProxyMapping, 100),
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
tokenStore: tokenStore,
peersManager: peersManager,
usersManager: usersManager,
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
updatesChan: make(chan *proto.ProxyMapping, 100),
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
tokenStore: tokenStore,
peersManager: peersManager,
usersManager: usersManager,
pkceCleanupCancel: cancel,
}
go s.cleanupPKCEVerifiers(ctx)
return s
}
// cleanupPKCEVerifiers periodically removes expired PKCE verifiers.
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
ticker := time.NewTicker(pkceVerifierTTL)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
s.pkceVerifiers.Range(func(key, value any) bool {
if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL {
s.pkceVerifiers.Delete(key)
}
return true
})
}
}
}
// Close stops background goroutines.
func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel()
}
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) {
@@ -128,12 +167,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
proxyAddress := req.GetAddress()
log.WithFields(log.Fields{
"proxy_id": proxyID,
"address": proxyAddress,
"version": req.GetVersion(),
"started": req.GetStartedAt().AsTime(),
}).Info("Proxy connected")
if !isProxyAddressValid(proxyAddress) {
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{
@@ -150,7 +186,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.WithFields(log.Fields{
"proxy_id": proxyID,
"address": proxyAddress,
"cluster_addr": extractClusterAddr(proxyAddress),
"cluster_addr": proxyAddress,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
@@ -161,8 +197,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}()
if err := s.sendSnapshot(ctx, conn); err != nil {
log.Errorf("Failed to send snapshot to proxy %s: %v", proxyID, err)
return err
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
@@ -170,73 +205,81 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
select {
case err := <-errChan:
return err
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
case <-connCtx.Done():
return connCtx.Err()
}
}
// sendSnapshot sends the initial snapshot of reverse proxies to the connecting proxy.
// Only reverse proxies matching the proxy's cluster address are sent.
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
reverseProxies, err := s.reverseProxyManager.GetGlobalReverseProxies(ctx)
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
if err != nil {
return fmt.Errorf("get reverse proxies from store: %w", err)
return fmt.Errorf("get services from store: %w", err)
}
proxyClusterAddr := extractClusterAddr(conn.address)
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
}
for _, rp := range reverseProxies {
if !rp.Enabled {
var filtered []*reverseproxy.Service
for _, service := range services {
if !service.Enabled {
continue
}
if rp.ProxyCluster != "" && proxyClusterAddr != "" && rp.ProxyCluster != proxyClusterAddr {
if service.ProxyCluster == "" || service.ProxyCluster != conn.address {
continue
}
filtered = append(filtered, service)
}
// Generate one-time authentication token for each proxy in the snapshot
if len(filtered) == 0 {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
InitialSyncComplete: true,
}); err != nil {
return fmt.Errorf("send snapshot completion: %w", err)
}
return nil
}
for i, service := range filtered {
// Generate one-time authentication token for each service in the snapshot
// Tokens are not persistent on the proxy, so we need to generate new ones on reconnection
token, err := s.tokenStore.GenerateToken(rp.AccountID, rp.ID, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
if err != nil {
log.WithFields(log.Fields{
"proxy": rp.Name,
"account": rp.AccountID,
}).WithError(err).Error("Failed to generate auth token for snapshot")
"service": service.Name,
"account": service.AccountID,
}).WithError(err).Error("failed to generate auth token for snapshot")
continue
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
rp.ToProtoMapping(
service.ToProtoMapping(
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
token,
s.GetOIDCValidationConfig(),
),
},
InitialSyncComplete: i == len(filtered)-1,
}); err != nil {
log.WithError(err).Error("Failed to send proxy mapping")
continue
log.WithFields(log.Fields{
"domain": service.Domain,
"account": service.AccountID,
}).WithError(err).Error("failed to send proxy mapping")
return fmt.Errorf("send proxy mapping: %w", err)
}
}
return nil
}
// extractClusterAddr extracts the host from a proxy address URL.
func extractClusterAddr(addr string) string {
if addr == "" {
return ""
}
u, err := url.Parse(addr)
if err != nil {
return addr
}
host := u.Host
if h, _, err := net.SplitHostPort(host); err == nil {
return h
}
return host
// isProxyAddressValid validates a proxy address
func isProxyAddressValid(addr string) bool {
_, err := domain.ValidateDomains([]string{addr})
return err == nil
}
// sender handles sending messages to proxy
@@ -245,7 +288,6 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
select {
case msg := <-conn.sendChan:
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil {
log.Errorf("Failed to send message to proxy %s: %v", conn.proxyID, err)
errChan <- err
return
}
@@ -260,10 +302,10 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
accessLog := req.GetLog()
fields := log.Fields{
"reverse_proxy_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
"host": accessLog.GetHost(),
"source_ip": accessLog.GetSourceIp(),
"service_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
"host": accessLog.GetHost(),
"source_ip": accessLog.GetSourceIp(),
}
if mechanism := accessLog.GetAuthMechanism(); mechanism != "" {
fields["auth_mechanism"] = mechanism
@@ -286,24 +328,29 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
if err := s.accessLogManager.SaveAccessLog(ctx, logEntry); err != nil {
log.WithContext(ctx).Errorf("failed to save access log: %v", err)
return nil, status.Errorf(codes.Internal, "failed to save access log: %v", err)
return nil, status.Errorf(codes.Internal, "save access log: %v", err)
}
return &proto.SendAccessLogResponse{}, nil
}
// SendReverseProxyUpdate broadcasts a reverse proxy update to all connected proxies.
// Management should call this when reverse proxies are created/updated/removed
func (s *ProxyServiceServer) SendReverseProxyUpdate(update *proto.ProxyMapping) {
// Send it to all connected proxies
log.Debugf("Broadcasting reverse proxy update to all connected proxies")
// SendServiceUpdate broadcasts a service update to all connected proxy servers.
// Management should call this when services are created/updated/removed.
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
log.Debugf("Broadcasting service update to all connected proxy servers")
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
msg := s.perProxyMessage(update, conn.proxyID)
if msg == nil {
return true
}
select {
case conn.sendChan <- update:
log.Debugf("Sent reverse proxy update with id %s to proxy %s", update.Id, conn.proxyID)
case conn.sendChan <- msg:
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID)
default:
log.Warnf("Failed to send reverse proxy update to proxy %s (channel full)", conn.proxyID)
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
}
return true
})
@@ -366,11 +413,13 @@ func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
}
}
// SendReverseProxyUpdateToCluster sends a reverse proxy update to all proxies in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxies (backward compatibility).
func (s *ProxyServiceServer) SendReverseProxyUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
if clusterAddr == "" {
s.SendReverseProxyUpdate(update)
s.SendServiceUpdate(update)
return
}
@@ -380,22 +429,62 @@ func (s *ProxyServiceServer) SendReverseProxyUpdateToCluster(update *proto.Proxy
return
}
log.Debugf("Sending reverse proxy update to cluster %s", clusterAddr)
log.Debugf("Sending service update to cluster %s", clusterAddr)
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxyID := key.(string)
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
msg := s.perProxyMessage(update, proxyID)
if msg == nil {
return true
}
select {
case conn.sendChan <- update:
log.Debugf("Sent reverse proxy update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
case conn.sendChan <- msg:
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.Warnf("Failed to send reverse proxy update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
}
return true
})
}
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original message is
// returned unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails (the proxy should be skipped).
func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping {
if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" {
return update
}
token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute)
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil
}
msg := shallowCloneMapping(update)
msg.AuthToken = token
return msg
}
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the
// same slice/pointer fields. Only scalar fields that differ per proxy (AuthToken)
// should be set on the copy.
func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
return &proto.ProxyMapping{
Type: m.Type,
Id: m.Id,
AccountId: m.AccountId,
Domain: m.Domain,
Path: m.Path,
Auth: m.Auth,
PassHostHeader: m.PassHostHeader,
RewriteRedirects: m.RewriteRedirects,
}
}
// GetAvailableClusters returns information about all connected proxy clusters.
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
clusterCounts := make(map[string]int)
@@ -424,10 +513,10 @@ func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
proxy, err := s.reverseProxyManager.GetProxyByID(ctx, req.GetAccountId(), req.GetId())
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
// TODO: log the error
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
}
var authenticated bool
@@ -435,10 +524,9 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
var method proxyauth.Method
switch v := req.GetRequest().(type) {
case *proto.AuthenticateRequest_Pin:
auth := proxy.Auth.PinAuth
auth := service.Auth.PinAuth
if auth == nil || !auth.Enabled {
// TODO: log
// Break here and use the default authenticated == false.
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", req.GetId())
break
}
err = argon2id.Verify(v.Pin.GetPin(), auth.Pin)
@@ -454,10 +542,9 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
userId = "pin-user"
method = proxyauth.MethodPIN
case *proto.AuthenticateRequest_Password:
auth := proxy.Auth.PasswordAuth
auth := service.Auth.PasswordAuth
if auth == nil || !auth.Enabled {
// TODO: log
// Break here and use the default authenticated == false.
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", req.GetId())
break
}
err = argon2id.Verify(v.Password.GetPassword(), auth.Password)
@@ -475,17 +562,17 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
}
var token string
if authenticated && proxy.SessionPrivateKey != "" {
if authenticated && service.SessionPrivateKey != "" {
token, err = sessionkey.SignToken(
proxy.SessionPrivateKey,
service.SessionPrivateKey,
userId,
proxy.Domain,
service.Domain,
method,
proxyauth.DefaultSessionExpiry,
)
if err != nil {
log.WithError(err).Error("Failed to sign session token")
authenticated = false
log.WithContext(ctx).WithError(err).Error("failed to sign session token")
return nil, status.Errorf(codes.Internal, "sign session token: %v", err)
}
}
@@ -498,45 +585,45 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
// SendStatusUpdate handles status updates from proxy clients
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
accountID := req.GetAccountId()
reverseProxyID := req.GetReverseProxyId()
serviceID := req.GetServiceId()
protoStatus := req.GetStatus()
certificateIssued := req.GetCertificateIssued()
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"service_id": serviceID,
"account_id": accountID,
"status": protoStatus,
"certificate_issued": certificateIssued,
"error_message": req.GetErrorMessage(),
}).Debug("Status update from proxy")
}).Debug("Status update from proxy server")
if reverseProxyID == "" || accountID == "" {
return nil, status.Errorf(codes.InvalidArgument, "reverse_proxy_id and account_id are required")
if serviceID == "" || accountID == "" {
return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required")
}
if certificateIssued {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, reverseProxyID); err != nil {
log.WithContext(ctx).WithError(err).Error("Failed to set certificate issued timestamp")
return nil, status.Errorf(codes.Internal, "failed to update certificate timestamp: %v", err)
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
}
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"service_id": serviceID,
"account_id": accountID,
}).Info("Certificate issued timestamp updated")
}
internalStatus := protoStatusToInternal(protoStatus)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, reverseProxyID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("Failed to set proxy status")
return nil, status.Errorf(codes.Internal, "failed to update proxy status: %v", err)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
}
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"status": internalStatus,
}).Info("Proxy status updated")
"service_id": serviceID,
"account_id": accountID,
"status": internalStatus,
}).Info("Service status updated")
return &proto.SendStatusUpdateResponse{}, nil
}
@@ -563,47 +650,47 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStat
// CreateProxyPeer handles proxy peer creation with one-time token authentication
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
reverseProxyID := req.GetReverseProxyId()
serviceID := req.GetServiceId()
accountID := req.GetAccountId()
token := req.GetToken()
cluster := req.GetCluster()
key := req.WireguardPublicKey
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"cluster": cluster,
"service_id": serviceID,
"account_id": accountID,
"cluster": cluster,
}).Debug("CreateProxyPeer request received")
if reverseProxyID == "" || accountID == "" || token == "" {
if serviceID == "" || accountID == "" || token == "" {
log.Warn("CreateProxyPeer: missing required fields")
return &proto.CreateProxyPeerResponse{
Success: false,
ErrorMessage: strPtr("missing required fields: reverse_proxy_id, account_id, and token are required"),
ErrorMessage: strPtr("missing required fields: service_id, account_id, and token are required"),
}, nil
}
if err := s.tokenStore.ValidateAndConsume(token, accountID, reverseProxyID); err != nil {
if err := s.tokenStore.ValidateAndConsume(token, accountID, serviceID); err != nil {
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"service_id": serviceID,
"account_id": accountID,
}).WithError(err).Warn("CreateProxyPeer: token validation failed")
return &proto.CreateProxyPeerResponse{
Success: false,
ErrorMessage: strPtr("authentication failed: invalid or expired token"),
}, status.Errorf(codes.Unauthenticated, "token validation failed: %v", err)
}, status.Errorf(codes.Unauthenticated, "token validation: %v", err)
}
err := s.peersManager.CreateProxyPeer(ctx, accountID, key, cluster)
if err != nil {
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
}).WithError(err).Error("CreateProxyPeer: failed to create proxy peer")
"service_id": serviceID,
"account_id": accountID,
}).WithError(err).Error("failed to create proxy peer")
return &proto.CreateProxyPeerResponse{
Success: false,
ErrorMessage: strPtr(fmt.Sprintf("failed to create proxy peer: %v", err)),
}, status.Errorf(codes.Internal, "failed to create proxy peer: %v", err)
ErrorMessage: strPtr(fmt.Sprintf("create proxy peer: %v", err)),
}, status.Errorf(codes.Internal, "create proxy peer: %v", err)
}
return &proto.CreateProxyPeerResponse{
@@ -619,31 +706,30 @@ func strPtr(s string) *string {
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
redirectURL, err := url.Parse(req.GetRedirectUrl())
if err != nil {
// TODO: log
return nil, status.Errorf(codes.InvalidArgument, "failed to parse redirect url: %v", err)
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
}
// Validate redirectURL against known proxy endpoints to avoid abuse of OIDC redirection.
proxies, err := s.reverseProxyManager.GetAccountReverseProxies(ctx, req.GetAccountId())
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil {
// TODO: log
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
}
var found bool
for _, proxy := range proxies {
if proxy.Domain == redirectURL.Hostname() {
for _, service := range services {
if service.Domain == redirectURL.Hostname() {
found = true
break
}
}
if !found {
// TODO: log
return nil, status.Errorf(codes.FailedPrecondition, "reverse proxy not found in store")
log.WithContext(ctx).Debugf("OIDC redirect URL %q does not match any service domain", redirectURL.Hostname())
return nil, status.Errorf(codes.FailedPrecondition, "service not found in store")
}
provider, err := oidc.NewProvider(ctx, s.oidcConfig.Issuer)
if err != nil {
// TODO: log
return nil, status.Errorf(codes.FailedPrecondition, "failed to create OIDC provider: %v", err)
log.WithContext(ctx).Errorf("failed to create OIDC provider: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "create OIDC provider: %v", err)
}
scopes := s.oidcConfig.Scopes
@@ -651,13 +737,23 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
scopes = []string{oidc.ScopeOpenID, "profile", "email"}
}
// Generate a random nonce to ensure each OIDC request gets a unique state.
// Without this, multiple requests to the same URL would generate the same state
// but different PKCE verifiers, causing the later verifier to overwrite the earlier one.
nonce := make([]byte, 16)
if _, err := rand.Read(nonce); err != nil {
return nil, status.Errorf(codes.Internal, "generate nonce: %v", err)
}
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
// Using an HMAC here to avoid redirection state being modified.
// State format: base64(redirectURL)|hmac
hmacSum := s.generateHMAC(redirectURL.String())
state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), hmacSum)
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
payload := redirectURL.String() + "|" + nonceB64
hmacSum := s.generateHMAC(payload)
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
codeVerifier := oauth2.GenerateVerifier()
s.pkceVerifiers.Store(state, codeVerifier)
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()})
return &proto.GetOIDCURLResponse{
Url: (&oauth2.Config{
@@ -698,18 +794,24 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
if !ok {
return "", "", errors.New("no verifier for state")
}
verifier, ok = v.(string)
entry, ok := v.(pkceEntry)
if !ok {
return "", "", errors.New("invalid verifier for state")
}
if time.Since(entry.createdAt) > pkceVerifierTTL {
return "", "", errors.New("PKCE verifier expired")
}
verifier = entry.verifier
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
parts := strings.Split(state, "|")
if len(parts) != 2 {
if len(parts) != 3 {
return "", "", errors.New("invalid state format")
}
encodedURL := parts[0]
providedHMAC := parts[1]
nonce := parts[1]
providedHMAC := parts[2]
redirectURLBytes, err := base64.URLEncoding.DecodeString(encodedURL)
if err != nil {
@@ -717,10 +819,11 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
}
redirectURL = string(redirectURLBytes)
expectedHMAC := s.generateHMAC(redirectURL)
payload := redirectURL + "|" + nonce
expectedHMAC := s.generateHMAC(payload)
if !hmac.Equal([]byte(providedHMAC), []byte(expectedHMAC)) {
return "", "", fmt.Errorf("invalid state signature")
return "", "", errors.New("invalid state signature")
}
return verifier, redirectURL, nil
@@ -728,29 +831,29 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the proxy by domain to get its signing key
proxies, err := s.reverseProxyManager.GetGlobalReverseProxies(ctx)
// Find the service by domain to get its signing key
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
if err != nil {
return "", fmt.Errorf("get reverse proxies: %w", err)
return "", fmt.Errorf("get services: %w", err)
}
var proxy *reverseproxy.ReverseProxy
for _, p := range proxies {
if p.Domain == domain {
proxy = p
var service *reverseproxy.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
break
}
}
if proxy == nil {
return "", fmt.Errorf("reverse proxy not found for domain: %s", domain)
if service == nil {
return "", fmt.Errorf("service not found for domain: %s", domain)
}
if proxy.SessionPrivateKey == "" {
if service.SessionPrivateKey == "" {
return "", fmt.Errorf("no session key configured for domain: %s", domain)
}
return sessionkey.SignToken(
proxy.SessionPrivateKey,
service.SessionPrivateKey,
userID,
domain,
method,
@@ -758,8 +861,8 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
)
}
// ValidateUserGroupAccess checks if a user has access to a reverse proxy.
// It looks up the proxy within the user's account only, then optionally checks
// ValidateUserGroupAccess checks if a user has access to a service.
// It looks up the service within the user's account only, then optionally checks
// group membership if BearerAuth with DistributionGroups is configured.
func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain, userID string) error {
user, err := s.usersManager.GetUser(ctx, userID)
@@ -767,16 +870,16 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
return fmt.Errorf("user not found: %s", userID)
}
proxy, err := s.getAccountProxyByDomain(ctx, user.AccountID, domain)
service, err := s.getAccountServiceByDomain(ctx, user.AccountID, domain)
if err != nil {
return err
}
if proxy.Auth.BearerAuth == nil || !proxy.Auth.BearerAuth.Enabled {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil
}
allowedGroups := proxy.Auth.BearerAuth.DistributionGroups
allowedGroups := service.Auth.BearerAuth.DistributionGroups
if len(allowedGroups) == 0 {
return nil
}
@@ -800,19 +903,19 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
}
func (s *ProxyServiceServer) getAccountProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error) {
proxies, err := s.reverseProxyManager.GetAccountReverseProxies(ctx, accountID)
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account reverse proxies: %w", err)
return nil, fmt.Errorf("get account services: %w", err)
}
for _, proxy := range proxies {
if proxy.Domain == domain {
return proxy, nil
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("reverse proxy not found for domain %s in account %s", domain, accountID)
return nil, fmt.Errorf("service not found for domain %s in account %s", domain, accountID)
}
// ValidateSession validates a session token and checks if the user has access to the domain.
@@ -827,27 +930,29 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
proxy, err := s.getProxyByDomain(ctx, domain)
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"error": err.Error(),
}).Debug("ValidateSession: proxy not found")
}).Debug("ValidateSession: service not found")
//nolint:nilerr
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "proxy_not_found",
DeniedReason: "service_not_found",
}, nil
}
pubKeyBytes, err := base64.StdEncoding.DecodeString(proxy.SessionPublicKey)
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"error": err.Error(),
}).Error("ValidateSession: decode public key")
//nolint:nilerr
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "invalid_proxy_config",
DeniedReason: "invalid_service_config",
}, nil
}
@@ -857,6 +962,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
"domain": domain,
"error": err.Error(),
}).Debug("ValidateSession: invalid session token")
//nolint:nilerr
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "invalid_token",
@@ -870,31 +976,34 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
"user_id": userID,
"error": err.Error(),
}).Debug("ValidateSession: user not found")
//nolint:nilerr
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "user_not_found",
}, nil
}
if user.AccountID != proxy.AccountID {
if user.AccountID != service.AccountID {
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"user_account": user.AccountID,
"proxy_account": proxy.AccountID,
"domain": domain,
"user_id": userID,
"user_account": user.AccountID,
"service_account": service.AccountID,
}).Debug("ValidateSession: user account mismatch")
//nolint:nilerr
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "account_mismatch",
}, nil
}
if err := s.checkGroupAccess(proxy, user); err != nil {
if err := s.checkGroupAccess(service, user); err != nil {
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"error": err.Error(),
}).Debug("ValidateSession: access denied")
//nolint:nilerr
return &proto.ValidateSessionResponse{
Valid: false,
UserId: user.Id,
@@ -916,27 +1025,27 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
func (s *ProxyServiceServer) getProxyByDomain(ctx context.Context, domain string) (*reverseproxy.ReverseProxy, error) {
proxies, err := s.reverseProxyManager.GetGlobalReverseProxies(ctx)
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get reverse proxies: %w", err)
return nil, fmt.Errorf("get services: %w", err)
}
for _, proxy := range proxies {
if proxy.Domain == domain {
return proxy, nil
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("reverse proxy not found for domain: %s", domain)
return nil, fmt.Errorf("service not found for domain: %s", domain)
}
func (s *ProxyServiceServer) checkGroupAccess(proxy *reverseproxy.ReverseProxy, user *types.User) error {
if proxy.Auth.BearerAuth == nil || !proxy.Auth.BearerAuth.Enabled {
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil
}
allowedGroups := proxy.Auth.BearerAuth.DistributionGroups
allowedGroups := service.Auth.BearerAuth.DistributionGroups
if len(allowedGroups) == 0 {
return nil
}

View File

@@ -13,38 +13,38 @@ import (
)
type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.ReverseProxy
proxiesByAccount map[string][]*reverseproxy.Service
err error
}
func (m *mockReverseProxyManager) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
if m.err != nil {
return nil, m.err
}
return m.proxiesByAccount[accountID], nil
}
func (m *mockReverseProxyManager) GetGlobalReverseProxies(ctx context.Context) ([]*reverseproxy.ReverseProxy, error) {
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetAllReverseProxies(ctx context.Context, accountID, userID string) ([]*reverseproxy.ReverseProxy, error) {
return nil, nil
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return []*reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) GetReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) {
return nil, nil
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) CreateReverseProxy(ctx context.Context, accountID, userID string, rp *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
return nil, nil
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) UpdateReverseProxy(ctx context.Context, accountID, userID string, rp *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
return nil, nil
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) DeleteReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) error {
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
return nil
}
@@ -56,19 +56,19 @@ func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reve
return nil
}
func (m *mockReverseProxyManager) ReloadAllReverseProxiesForAccount(ctx context.Context, accountID string) error {
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
return nil
}
func (m *mockReverseProxyManager) ReloadReverseProxy(ctx context.Context, accountID, reverseProxyID string) error {
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) GetProxyByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) {
return nil, nil
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) GetProxyIDByTargetID(_ context.Context, _, _ string) (string, error) {
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
@@ -93,7 +93,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name string
domain string
userID string
proxiesByAccount map[string][]*reverseproxy.ReverseProxy
proxiesByAccount map[string][]*reverseproxy.Service
users map[string]*types.User
proxyErr error
userErr error
@@ -104,7 +104,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not found",
domain: "app.example.com",
userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
users: map[string]*types.User{},
@@ -115,31 +115,31 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy not found in user's account",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{},
proxiesByAccount: map[string][]*reverseproxy.Service{},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "reverse proxy not found",
expectErrMsg: "service not found",
},
{
name: "proxy exists in different account - not accessible",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "reverse proxy not found",
expectErrMsg: "service not found",
},
{
name: "no bearer auth configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
},
users: map[string]*types.User{
@@ -151,7 +151,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth disabled - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
@@ -169,7 +169,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
@@ -190,7 +190,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not in allowed groups",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
@@ -212,7 +212,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in one of the allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
@@ -233,7 +233,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in all allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
@@ -260,13 +260,13 @@ func TestValidateUserGroupAccess(t *testing.T) {
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "get account reverse proxies",
expectErrMsg: "get account services",
},
{
name: "multiple proxies in account - finds correct one",
domain: "app2.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
@@ -310,7 +310,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name string
accountID string
domain string
proxiesByAccount map[string][]*reverseproxy.ReverseProxy
proxiesByAccount map[string][]*reverseproxy.Service
err error
expectProxy bool
expectErr bool
@@ -319,7 +319,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy found",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"},
@@ -332,7 +332,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy not found in account",
accountID: "account1",
domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
expectProxy: false,
@@ -342,7 +342,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "empty proxy list for account",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{},
proxiesByAccount: map[string][]*reverseproxy.Service{},
expectProxy: false,
expectErr: true,
},
@@ -366,7 +366,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
},
}
proxy, err := server.getAccountProxyByDomain(context.Background(), tt.accountID, tt.domain)
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
if tt.expectErr {
require.Error(t, err)

View File

@@ -0,0 +1,232 @@
package grpc
import (
"crypto/rand"
"encoding/base64"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
)
// registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
ch := make(chan *proto.ProxyMapping, 10)
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
sendChan: ch,
}
s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
return ch
}
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
select {
case msg := <-ch:
return msg
case <-time.After(time.Second):
return nil
}
}
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
const cluster = "proxy.example.com"
const numProxies = 3
channels := make([]chan *proto.ProxyMapping, numProxies)
for i := range numProxies {
id := "proxy-" + string(rune('a'+i))
channels[i] = registerFakeProxy(s, id, cluster)
}
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
Path: []*proto.PathMapping{
{Path: "/", Target: "http://10.0.0.1:8080/"},
},
}
s.SendServiceUpdateToCluster(update, cluster)
tokens := make([]string, numProxies)
for i, ch := range channels {
msg := drainChannel(ch)
require.NotNil(t, msg, "proxy %d should receive a message", i)
assert.Equal(t, update.Domain, msg.Domain)
assert.Equal(t, update.Id, msg.Id)
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
tokens[i] = msg.AuthToken
}
// All tokens must be unique
tokenSet := make(map[string]struct{})
for i, tok := range tokens {
_, exists := tokenSet[tok]
assert.False(t, exists, "proxy %d got duplicate token", i)
tokenSet[tok] = struct{}{}
}
// Each token must be independently consumable
for i, tok := range tokens {
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
assert.NoError(t, err, "proxy %d token should validate successfully", i)
}
}
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster)
ch2 := registerFakeProxy(s, "proxy-b", cluster)
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdateToCluster(update, cluster)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
// Delete operations should not generate tokens
assert.Empty(t, msg1.AuthToken)
assert.Empty(t, msg2.AuthToken)
// No tokens should have been created
assert.Equal(t, 0, tokenStore.GetTokenCount())
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdate(update)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
assert.NotEmpty(t, msg1.AuthToken)
assert.NotEmpty(t, msg2.AuthToken)
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
// Both tokens should validate
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
}
// generateState creates a state using the same format as GetOIDCURL.
func generateState(s *ProxyServiceServer, redirectURL string) string {
nonce := make([]byte, 16)
_, _ = rand.Read(nonce)
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
payload := redirectURL + "|" + nonceB64
hmacSum := s.generateHMAC(payload)
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
}
func TestOAuthState_NeverTheSame(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
redirectURL := "https://app.example.com/callback"
// Generate 100 states for the same redirect URL
states := make(map[string]bool)
for i := 0; i < 100; i++ {
state := generateState(s, redirectURL)
// State must have 3 parts: base64(url)|nonce|hmac
parts := strings.Split(state, "|")
require.Equal(t, 3, len(parts), "state must have 3 parts")
// State must be unique
require.False(t, states[state], "state %d is a duplicate", i)
states[state] = true
}
}
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Old format had only 2 parts: base64(url)|hmac
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("base64url|hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format")
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Store with tampered HMAC
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature")
}

View File

@@ -54,7 +54,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.ReverseProxy{
testProxy := &reverseproxy.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
@@ -68,9 +68,9 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
},
},
}
require.NoError(t, testStore.CreateReverseProxy(ctx, testProxy))
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.ReverseProxy{
restrictedProxy := &reverseproxy.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
@@ -85,7 +85,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
},
},
}
require.NoError(t, testStore.CreateReverseProxy(ctx, restrictedProxy))
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
}
func generateSessionKeyPair(t *testing.T) (string, string) {
@@ -106,7 +106,7 @@ func TestValidateSession_UserAllowed(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetReverseProxyByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
@@ -126,7 +126,7 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetReverseProxyByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
@@ -146,7 +146,7 @@ func TestValidateSession_UserInDifferentAccount(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetReverseProxyByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
@@ -165,7 +165,7 @@ func TestValidateSession_UserNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetReverseProxyByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
@@ -184,7 +184,7 @@ func TestValidateSession_ProxyNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetReverseProxyByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
@@ -243,23 +243,23 @@ type testValidateSessionProxyManager struct {
store store.Store
}
func (m *testValidateSessionProxyManager) GetAllReverseProxies(_ context.Context, _, _ string) ([]*reverseproxy.ReverseProxy, error) {
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetReverseProxy(_ context.Context, _, _, _ string) (*reverseproxy.ReverseProxy, error) {
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CreateReverseProxy(_ context.Context, _, _ string, _ *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) UpdateReverseProxy(_ context.Context, _, _ string, _ *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) DeleteReverseProxy(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
@@ -271,27 +271,27 @@ func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ stri
return nil
}
func (m *testValidateSessionProxyManager) ReloadAllReverseProxiesForAccount(_ context.Context, _ string) error {
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadReverseProxy(_ context.Context, _, _ string) error {
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetGlobalReverseProxies(ctx context.Context) ([]*reverseproxy.ReverseProxy, error) {
return m.store.GetReverseProxies(ctx, store.LockingStrengthNone)
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testValidateSessionProxyManager) GetProxyByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) {
return m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, proxyID)
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testValidateSessionProxyManager) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testValidateSessionProxyManager) GetProxyIDByTargetID(_ context.Context, _, _ string) (string, error) {
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}

View File

@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetReverseProxyManager(reverseProxyManager reverseproxy.Manager) {
am.reverseProxyManager = reverseProxyManager
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
am.reverseProxyManager = serviceManager
}
func isUniqueConstraintError(err error) bool {
@@ -327,8 +327,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
}
if err = am.reverseProxyManager.ReloadAllReverseProxiesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all reverse proxy for account %s: %v", accountID, err)
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
}
updateAccountPeers = true
}

View File

@@ -140,5 +140,5 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetReverseProxyManager(reverseProxyManager reverseproxy.Manager)
SetServiceManager(serviceManager reverseproxy.Manager)
}

View File

@@ -27,6 +27,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbAccount "github.com/netbirdio/netbird/management/server/account"
@@ -1800,6 +1802,14 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24",
},
},
Services: []*reverseproxy.Service{
{
ID: "service1",
Name: "test-service",
AccountID: "account1",
Targets: []*reverseproxy.Target{},
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
}
account.InitOnce()
@@ -3112,6 +3122,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
return manager, updateManager, nil
}

View File

@@ -204,9 +204,9 @@ const (
UserInviteLinkRegenerated Activity = 106
UserInviteLinkDeleted Activity = 107
ReverseProxyCreated Activity = 108
ReverseProxyUpdated Activity = 109
ReverseProxyDeleted Activity = 110
ServiceCreated Activity = 108
ServiceUpdated Activity = 109
ServiceDeleted Activity = 110
AccountDeleted Activity = 99999
)
@@ -342,9 +342,9 @@ var activityMap = map[Activity]Code{
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
ReverseProxyCreated: {"Reverse proxy created", "reverseproxy.create"},
ReverseProxyUpdated: {"Reverse proxy updated", "reverseproxy.update"},
ReverseProxyDeleted: {"Reverse proxy deleted", "reverseproxy.delete"},
ServiceCreated: {"Service created", "service.create"},
ServiceUpdated: {"Service updated", "service.update"},
ServiceDeleted: {"Service deleted", "service.delete"},
}
// StringCode returns a string code of the activity

View File

@@ -703,7 +703,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)

View File

@@ -4,23 +4,28 @@ import (
"context"
"fmt"
"net/http"
"net/netip"
"os"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/types"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -32,6 +37,8 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -45,7 +52,6 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/networks"
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/handlers/users"
@@ -67,7 +73,7 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *domain.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -173,7 +179,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
// Register OAuth callback handler for proxy authentication
if proxyGRPCServer != nil {
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer)
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)
oauthHandler.RegisterEndpoints(router)
}

View File

@@ -152,6 +152,11 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
return
}
if peer.ProxyMeta.Embedded {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
return
}
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
@@ -319,6 +324,9 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
if peer.ProxyMeta.Embedded {
continue
}
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"net"
"net/http"
"net/netip"
"net/url"
"strings"
"time"
@@ -21,12 +22,13 @@ import (
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
type AuthCallbackHandler struct {
proxyService *nbgrpc.ProxyServiceServer
rateLimiter *middleware.APIRateLimiter
proxyService *nbgrpc.ProxyServiceServer
rateLimiter *middleware.APIRateLimiter
trustedProxies []netip.Prefix
}
// NewAuthCallbackHandler creates a new OAuth callback handler.
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler {
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer, trustedProxies []netip.Prefix) *AuthCallbackHandler {
rateLimiterConfig := &middleware.RateLimiterConfig{
RequestsPerMinute: 10,
Burst: 15,
@@ -35,8 +37,9 @@ func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallba
}
return &AuthCallbackHandler{
proxyService: proxyService,
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
proxyService: proxyService,
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
trustedProxies: trustedProxies,
}
}
@@ -46,7 +49,7 @@ func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
}
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
clientIP := h.resolveClientIP(r)
if !h.rateLimiter.Allow(clientIP) {
log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
@@ -149,23 +152,57 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config
return claims.Subject
}
// getClientIP extracts the client IP address from the request.
func getClientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
// resolveClientIP extracts the real client IP from the request.
// When trustedProxies is non-empty and the direct peer is trusted,
// it walks X-Forwarded-For right-to-left skipping trusted IPs.
// Otherwise it returns RemoteAddr directly.
func (h *AuthCallbackHandler) resolveClientIP(r *http.Request) string {
remoteIP := extractHost(r.RemoteAddr)
if len(h.trustedProxies) == 0 || !isTrustedProxy(remoteIP, h.trustedProxies) {
return remoteIP
}
xff := r.Header.Get("X-Forwarded-For")
if xff == "" {
return remoteIP
}
parts := strings.Split(xff, ",")
for i := len(parts) - 1; i >= 0; i-- {
ip := strings.TrimSpace(parts[i])
if ip == "" {
continue
}
if !isTrustedProxy(ip, h.trustedProxies) {
return ip
}
return xff
}
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
// All IPs in XFF are trusted; return the leftmost as best guess.
if first := strings.TrimSpace(parts[0]); first != "" {
return first
}
return remoteIP
}
// Fall back to RemoteAddr
host, _, err := net.SplitHostPort(r.RemoteAddr)
func extractHost(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return r.RemoteAddr
return remoteAddr
}
return host
}
func isTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
addr, err := netip.ParseAddr(ipStr)
if err != nil {
return false
}
for _, prefix := range trusted {
if prefix.Contains(addr) {
return true
}
}
return false
}

View File

@@ -19,7 +19,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -161,8 +161,8 @@ func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.Ac
return nil
}
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string) ([]*accesslogs.AccessLogEntry, error) {
return nil, nil
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, _ *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
return nil, 0, nil
}
func setupAuthCallbackTest(t *testing.T) *testSetup {
@@ -198,9 +198,9 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
usersManager,
)
proxyService.SetProxyManager(&testReverseProxyManager{store: testStore})
proxyService.SetProxyManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService)
handler := NewAuthCallbackHandler(proxyService, nil)
router := mux.NewRouter()
handler.RegisterEndpoints(router)
@@ -227,7 +227,7 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
testProxy := &reverseproxy.ReverseProxy{
testProxy := &reverseproxy.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
@@ -251,9 +251,9 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
}
require.NoError(t, testStore.CreateReverseProxy(ctx, testProxy))
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.ReverseProxy{
restrictedProxy := &reverseproxy.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
@@ -277,9 +277,9 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
}
require.NoError(t, testStore.CreateReverseProxy(ctx, restrictedProxy))
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
noAuthProxy := &reverseproxy.ReverseProxy{
noAuthProxy := &reverseproxy.Service{
ID: "noAuthProxyId",
AccountID: "testAccountId",
Name: "No Auth Proxy",
@@ -302,7 +302,7 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
}
require.NoError(t, testStore.CreateReverseProxy(ctx, noAuthProxy))
require.NoError(t, testStore.CreateService(ctx, noAuthProxy))
}
func strPtr(s string) *string {
@@ -340,60 +340,60 @@ func createTestAccountsAndUsers(t *testing.T, ctx context.Context, testStore sto
require.NoError(t, testStore.SaveUser(ctx, allowedUser))
}
// testReverseProxyManager is a minimal implementation for testing.
type testReverseProxyManager struct {
// testServiceManager is a minimal implementation for testing.
type testServiceManager struct {
store store.Store
}
func (m *testReverseProxyManager) GetAllReverseProxies(_ context.Context, _, _ string) ([]*reverseproxy.ReverseProxy, error) {
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *testReverseProxyManager) GetReverseProxy(_ context.Context, _, _, _ string) (*reverseproxy.ReverseProxy, error) {
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testReverseProxyManager) CreateReverseProxy(_ context.Context, _, _ string, _ *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testReverseProxyManager) UpdateReverseProxy(_ context.Context, _, _ string, _ *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testReverseProxyManager) DeleteReverseProxy(_ context.Context, _, _, _ string) error {
func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testReverseProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testReverseProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
return nil
}
func (m *testReverseProxyManager) ReloadAllReverseProxiesForAccount(_ context.Context, _ string) error {
func (m *testServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testReverseProxyManager) ReloadReverseProxy(_ context.Context, _, _ string) error {
func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testReverseProxyManager) GetGlobalReverseProxies(ctx context.Context) ([]*reverseproxy.ReverseProxy, error) {
return m.store.GetReverseProxies(ctx, store.LockingStrengthNone)
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testReverseProxyManager) GetProxyByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) {
return m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, proxyID)
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testReverseProxyManager) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testReverseProxyManager) GetProxyIDByTargetID(_ context.Context, _, _ string) (string, error) {
func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
@@ -446,7 +446,7 @@ func TestAuthCallback_ProxyNotFound(t *testing.T) {
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
require.NoError(t, setup.store.DeleteReverseProxy(context.Background(), "testAccountId", "testProxyId"))
require.NoError(t, setup.store.DeleteService(context.Background(), "testAccountId", "testProxyId"))
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()

View File

@@ -3,6 +3,7 @@ package proxy
import (
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
@@ -12,7 +13,7 @@ import (
)
func TestAuthCallbackHandler_RateLimiting(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")
req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
@@ -54,7 +55,7 @@ func TestAuthCallbackHandler_RateLimiting(t *testing.T) {
}
func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
testIP := "10.0.0.50"
handler.rateLimiter.Reset(testIP)
@@ -75,46 +76,76 @@ func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) {
})
}
func TestGetClientIP(t *testing.T) {
func TestResolveClientIP(t *testing.T) {
trusted := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/12"),
}
tests := []struct {
name string
remoteAddr string
xForwardedFor string
xRealIP string
trustedProxy []netip.Prefix
expectedIP string
}{
{
name: "extract from RemoteAddr",
remoteAddr: "192.168.1.100:12345",
expectedIP: "192.168.1.100",
name: "no trusted proxies returns RemoteAddr",
remoteAddr: "203.0.113.50:9999",
xForwardedFor: "1.2.3.4",
trustedProxy: nil,
expectedIP: "203.0.113.50",
},
{
name: "extract from X-Forwarded-For single IP",
remoteAddr: "10.0.0.1:54321",
xForwardedFor: "203.0.113.195",
expectedIP: "203.0.113.195",
name: "untrusted RemoteAddr ignores XFF",
remoteAddr: "203.0.113.50:9999",
xForwardedFor: "1.2.3.4, 10.0.0.1",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "extract from X-Forwarded-For multiple IPs",
remoteAddr: "10.0.0.1:54321",
xForwardedFor: "203.0.113.195, 70.41.3.18, 150.172.238.178",
expectedIP: "203.0.113.195",
name: "trusted RemoteAddr with single client in XFF",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "203.0.113.50",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "extract from X-Real-IP",
remoteAddr: "10.0.0.1:54321",
xRealIP: "198.51.100.42",
expectedIP: "198.51.100.42",
name: "trusted RemoteAddr walks past trusted entries in XFF",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "203.0.113.50, 10.0.0.2, 172.16.0.5",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "X-Forwarded-For takes precedence over X-Real-IP",
remoteAddr: "10.0.0.1:54321",
xForwardedFor: "203.0.113.195",
xRealIP: "198.51.100.42",
expectedIP: "203.0.113.195",
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
remoteAddr: "10.0.0.1:5000",
trustedProxy: trusted,
expectedIP: "10.0.0.1",
},
{
name: "handle RemoteAddr without port",
name: "all XFF IPs trusted returns leftmost",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "10.0.0.2, 172.16.0.1, 10.0.0.3",
trustedProxy: trusted,
expectedIP: "10.0.0.2",
},
{
name: "XFF with whitespace",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: " 203.0.113.50 , 10.0.0.2 ",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "multi-hop with mixed trust",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "8.8.8.8, 203.0.113.50, 172.16.0.1",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "RemoteAddr without port",
remoteAddr: "192.168.1.100",
expectedIP: "192.168.1.100",
},
@@ -122,24 +153,22 @@ func TestGetClientIP(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, tt.trustedProxy)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
}
if tt.xRealIP != "" {
req.Header.Set("X-Real-IP", tt.xRealIP)
}
ip := getClientIP(req)
assert.Equal(t, tt.expectedIP, ip, "Extracted IP should match expected")
ip := handler.resolveClientIP(req)
assert.Equal(t, tt.expectedIP, ip)
})
}
}
func TestAuthCallbackHandler_RateLimiterConfiguration(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")

View File

@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/management-integrations/integrations"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
@@ -130,8 +131,10 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = ok
if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = true
}
}
// Email is now extracted in ToUserAuth (from claims or userinfo endpoint)
@@ -207,8 +210,10 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = ok
if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = true
}
}
return nbcontext.SetUserAuthInRequest(r, userAuth), nil

View File

@@ -627,15 +627,14 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
},
},
{
name: "Valid PAT Token accesses child",
name: "PAT Token with account param ignored in public version",
path: "/test?account=xyz",
authHeader: "Token " + PAT,
expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsChild: true,
IsPAT: true,
},
},
@@ -652,15 +651,14 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
},
{
name: "Valid JWT Token with child",
name: "JWT Token with account param ignored in public version",
path: "/test?account=xyz",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsChild: true,
},
},
}

View File

@@ -10,6 +10,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
@@ -86,6 +90,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
t.Fatalf("Failed to create manager: %v", err)
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(reverseProxyManager)
am.SetServiceManager(reverseProxyManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManagerMock := &serverauth.MockManager{
@@ -102,7 +114,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
@@ -147,6 +148,10 @@ type MockAccountManager struct {
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
// Mock implementation - no-op
}
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
if am.CreatePeerJobFunc != nil {
return am.CreatePeerJobFunc(ctx, accountID, peerID, userID, job)

View File

@@ -29,7 +29,7 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
@@ -52,7 +52,7 @@ func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
@@ -75,7 +75,7 @@ func Test_GetNetworkReturnsNetwork(t *testing.T) {
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
networks, err := manager.GetNetwork(ctx, accountID, userID, networkID)
@@ -98,7 +98,7 @@ func Test_GetNetworkReturnsPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
network, err := manager.GetNetwork(ctx, accountID, userID, networkID)
@@ -123,7 +123,7 @@ func Test_CreateNetworkSuccessfully(t *testing.T) {
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
createdNetwork, err := manager.CreateNetwork(ctx, userID, network)
@@ -148,7 +148,7 @@ func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
createdNetwork, err := manager.CreateNetwork(ctx, userID, network)
@@ -171,7 +171,7 @@ func Test_DeleteNetworkSuccessfully(t *testing.T) {
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
err = manager.DeleteNetwork(ctx, accountID, userID, networkID)
@@ -193,7 +193,7 @@ func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
err = manager.DeleteNetwork(ctx, accountID, userID, networkID)
@@ -218,7 +218,7 @@ func Test_UpdateNetworkSuccessfully(t *testing.T) {
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network)
@@ -245,7 +245,7 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network)

View File

@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
go func() {
err := m.reverseProxyManager.ReloadAllReverseProxiesForAccount(ctx, resource.AccountID)
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
if err != nil {
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
}
@@ -322,12 +322,12 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return status.NewPermissionDeniedError()
}
proxyID, err := m.reverseProxyManager.GetProxyIDByTargetID(ctx, accountID, resourceID)
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by reverse proxy: %w", err)
return fmt.Errorf("failed to check if resource is used by service: %w", err)
}
if proxyID != "" {
return status.NewResourceInUseError(resourceID, proxyID)
if serviceID != "" {
return status.NewResourceInUseError(resourceID, serviceID)
}
var events []func()

View File

@@ -4,8 +4,10 @@ import (
"context"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -28,7 +30,9 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err)
@@ -49,7 +53,9 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.Error(t, err)
@@ -69,7 +75,9 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.NoError(t, err)
@@ -89,7 +97,9 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.Error(t, err)
@@ -112,7 +122,9 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
@@ -134,7 +146,9 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)
@@ -161,7 +175,10 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.NoError(t, err)
@@ -187,7 +204,9 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -214,7 +233,9 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -240,7 +261,9 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -270,7 +293,10 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.NoError(t, err)
@@ -302,7 +328,9 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -332,7 +360,9 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -361,7 +391,9 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -383,7 +415,10 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
@@ -404,7 +439,9 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)

View File

@@ -221,6 +221,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err
}
if peer.ProxyMeta.Embedded {
return fmt.Errorf("not allowed to update peer")
}
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
@@ -489,12 +493,12 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var settings *types.Settings
var eventsToStore []func()
proxyID, err := am.reverseProxyManager.GetProxyIDByTargetID(ctx, accountID, peerID)
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by reverse proxy: %w", err)
return fmt.Errorf("failed to check if resource is used by service: %w", err)
}
if proxyID != "" {
return status.NewPeerInUseError(peerID, proxyID)
if serviceID != "" {
return status.NewPeerInUseError(peerID, serviceID)
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {

View File

@@ -131,7 +131,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.ReverseProxy{}, &reverseproxy.Target{}, &domain.Domain{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{},
)
if err != nil {
@@ -1100,7 +1100,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("NetworkRouters").
Preload("NetworkResources").
Preload("Onboarding").
Preload("ReverseProxies").
Preload("Services.Targets").
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
@@ -1281,12 +1281,12 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
wg.Add(1)
go func() {
defer wg.Done()
proxies, err := s.getProxies(ctx, accountID)
services, err := s.getServices(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.ReverseProxies = proxies
account.Services = services
}()
wg.Add(1)
@@ -2063,39 +2063,39 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
return checks, nil
}
func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
const proxyQuery = `SELECT id, account_id, name, domain, enabled, auth,
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key
FROM reverse_proxies WHERE account_id = $1`
FROM services WHERE account_id = $1`
const targetsQuery = `SELECT id, account_id, reverse_proxy_id, path, host, port, protocol,
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
target_id, target_type, enabled
FROM targets WHERE reverse_proxy_id = ANY($1)`
FROM targets WHERE service_id = ANY($1)`
proxyRows, err := s.pool.Query(ctx, proxyQuery, accountID)
serviceRows, err := s.pool.Query(ctx, serviceQuery, accountID)
if err != nil {
return nil, err
}
proxies, err := pgx.CollectRows(proxyRows, func(row pgx.CollectableRow) (*reverseproxy.ReverseProxy, error) {
var p reverseproxy.ReverseProxy
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) {
var s reverseproxy.Service
var auth []byte
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
err := row.Scan(
&p.ID,
&p.AccountID,
&p.Name,
&p.Domain,
&p.Enabled,
&s.ID,
&s.AccountID,
&s.Name,
&s.Domain,
&s.Enabled,
&auth,
&createdAt,
&certIssuedAt,
&status,
&proxyCluster,
&p.PassHostHeader,
&p.RewriteRedirects,
&s.PassHostHeader,
&s.RewriteRedirects,
&sessionPrivateKey,
&sessionPublicKey,
)
@@ -2104,50 +2104,50 @@ func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverse
}
if auth != nil {
if err := json.Unmarshal(auth, &p.Auth); err != nil {
if err := json.Unmarshal(auth, &s.Auth); err != nil {
return nil, err
}
}
p.Meta = reverseproxy.ReverseProxyMeta{}
s.Meta = reverseproxy.ServiceMeta{}
if createdAt.Valid {
p.Meta.CreatedAt = createdAt.Time
s.Meta.CreatedAt = createdAt.Time
}
if certIssuedAt.Valid {
p.Meta.CertificateIssuedAt = certIssuedAt.Time
s.Meta.CertificateIssuedAt = certIssuedAt.Time
}
if status.Valid {
p.Meta.Status = status.String
s.Meta.Status = status.String
}
if proxyCluster.Valid {
p.ProxyCluster = proxyCluster.String
s.ProxyCluster = proxyCluster.String
}
if sessionPrivateKey.Valid {
p.SessionPrivateKey = sessionPrivateKey.String
s.SessionPrivateKey = sessionPrivateKey.String
}
if sessionPublicKey.Valid {
p.SessionPublicKey = sessionPublicKey.String
s.SessionPublicKey = sessionPublicKey.String
}
p.Targets = []*reverseproxy.Target{}
return &p, nil
s.Targets = []*reverseproxy.Target{}
return &s, nil
})
if err != nil {
return nil, err
}
if len(proxies) == 0 {
return proxies, nil
if len(services) == 0 {
return services, nil
}
proxyIDs := make([]string, len(proxies))
proxyMap := make(map[string]*reverseproxy.ReverseProxy)
for i, p := range proxies {
proxyIDs[i] = p.ID
proxyMap[p.ID] = p
serviceIDs := make([]string, len(services))
serviceMap := make(map[string]*reverseproxy.Service)
for i, s := range services {
serviceIDs[i] = s.ID
serviceMap[s.ID] = s
}
targetRows, err := s.pool.Query(ctx, targetsQuery, proxyIDs)
targetRows, err := s.pool.Query(ctx, targetsQuery, serviceIDs)
if err != nil {
return nil, err
}
@@ -2158,7 +2158,7 @@ func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverse
err := row.Scan(
&t.ID,
&t.AccountID,
&t.ReverseProxyID,
&t.ServiceID,
&path,
&t.Host,
&t.Port,
@@ -2180,12 +2180,12 @@ func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverse
}
for _, target := range targets {
if proxy, ok := proxyMap[target.ReverseProxyID]; ok {
proxy.Targets = append(proxy.Targets, target)
if service, ok := serviceMap[target.ServiceID]; ok {
service.Targets = append(service.Targets, target)
}
}
return proxies, nil
return services, nil
}
func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
@@ -2792,7 +2792,7 @@ func getGormConfig() *gorm.Config {
// newPostgresStore initializes a new Postgres store.
func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
dsn, ok := os.LookupEnv(postgresDsnEnv)
dsn, ok := lookupDSNEnv(postgresDsnEnv, postgresDsnEnvLegacy)
if !ok {
return nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
@@ -2801,7 +2801,7 @@ func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics, skipMig
// newMysqlStore initializes a new MySQL store.
func newMysqlStore(ctx context.Context, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
dsn, ok := os.LookupEnv(mysqlDsnEnv)
dsn, ok := lookupDSNEnv(mysqlDsnEnv, mysqlDsnEnvLegacy)
if !ok {
return nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
}
@@ -4825,35 +4825,35 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
return peerID, nil
}
func (s *SqlStore) CreateReverseProxy(ctx context.Context, proxy *reverseproxy.ReverseProxy) error {
proxyCopy := proxy.Copy()
if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt reverse proxy data: %w", err)
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err)
}
result := s.db.Create(proxyCopy)
result := s.db.Create(serviceCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create reverse proxy to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create reverse proxy to store")
log.WithContext(ctx).Errorf("failed to create service to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create service to store")
}
return nil
}
func (s *SqlStore) UpdateReverseProxy(ctx context.Context, proxy *reverseproxy.ReverseProxy) error {
proxyCopy := proxy.Copy()
if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt reverse proxy data: %w", err)
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err)
}
// Use a transaction to ensure atomic updates of the proxy and its targets
// Use a transaction to ensure atomic updates of the service and its targets
err := s.db.Transaction(func(tx *gorm.DB) error {
// Delete existing targets
if err := tx.Where("reverse_proxy_id = ?", proxyCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
return err
}
// Update the proxy and create new targets
if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(proxyCopy).Error; err != nil {
// Update the service and create new targets
if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(serviceCopy).Error; err != nil {
return err
}
@@ -4861,112 +4861,112 @@ func (s *SqlStore) UpdateReverseProxy(ctx context.Context, proxy *reverseproxy.R
})
if err != nil {
log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", err)
return status.Errorf(status.Internal, "failed to update reverse proxy to store")
log.WithContext(ctx).Errorf("failed to update service to store: %v", err)
return status.Errorf(status.Internal, "failed to update service to store")
}
return nil
}
func (s *SqlStore) DeleteReverseProxy(ctx context.Context, accountID, proxyID string) error {
result := s.db.Delete(&reverseproxy.ReverseProxy{}, accountAndIDQueryCondition, accountID, proxyID)
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete reverse proxy from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete reverse proxy from store")
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete service from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "reverse proxy %s not found", proxyID)
return status.Errorf(status.NotFound, "service %s not found", serviceID)
}
return nil
}
func (s *SqlStore) GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) {
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var proxy *reverseproxy.ReverseProxy
result := tx.Take(&proxy, accountAndIDQueryCondition, accountID, proxyID)
var service *reverseproxy.Service
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "reverse proxy %s not found", proxyID)
return nil, status.Errorf(status.NotFound, "service %s not found", serviceID)
}
log.WithContext(ctx).Errorf("failed to get reverse proxy from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
log.WithContext(ctx).Errorf("failed to get service from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get service from store")
}
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
}
return proxy, nil
return service, nil
}
func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error) {
var proxy *reverseproxy.ReverseProxy
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&proxy)
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
var service *reverseproxy.Service
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "reverse proxy with domain %s not found", domain)
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)
}
log.WithContext(ctx).Errorf("failed to get reverse proxy by domain from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy by domain from store")
log.WithContext(ctx).Errorf("failed to get service by domain from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get service by domain from store")
}
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
}
return proxy, nil
return service, nil
}
func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error) {
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var proxyList []*reverseproxy.ReverseProxy
result := tx.Find(&proxyList)
var serviceList []*reverseproxy.Service
result := tx.Find(&serviceList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get reverse proxy from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get services from store")
}
for _, proxy := range proxyList {
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
for _, service := range serviceList {
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
}
}
return proxyList, nil
return serviceList, nil
}
func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) {
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var proxyList []*reverseproxy.ReverseProxy
result := tx.Find(&proxyList, accountIDCondition, accountID)
var serviceList []*reverseproxy.Service
result := tx.Find(&serviceList, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get reverse proxy from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get services from store")
}
for _, proxy := range proxyList {
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
for _, service := range serviceList {
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
}
}
return proxyList, nil
return serviceList, nil
}
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
@@ -4976,11 +4976,11 @@ func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domain
result := tx.Take(&customDomain, accountAndIDQueryCondition, accountID, domainID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "reverse proxy custom domain %s not found", domainID)
return nil, status.Errorf(status.NotFound, "custom domain %s not found", domainID)
}
log.WithContext(ctx).Errorf("failed to get reverse proxy custom domain from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy custom domain from store")
log.WithContext(ctx).Errorf("failed to get custom domain from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get custom domain from store")
}
return customDomain, nil
@@ -5051,10 +5051,10 @@ func (s *SqlStore) CreateAccessLog(ctx context.Context, logEntry *accesslogs.Acc
result := s.db.Create(logEntry)
if result.Error != nil {
log.WithContext(ctx).WithFields(log.Fields{
"proxy_id": logEntry.ProxyID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
"service_id": logEntry.ServiceID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
}).Errorf("failed to create access log entry in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create access log entry in store")
}
@@ -5105,8 +5105,8 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
if filter.Search != nil {
searchPattern := "%" + *filter.Search + "%"
query = query.Where(
"location_connection_ip LIKE ? OR host LIKE ? OR path LIKE ? OR CONCAT(host, path) LIKE ? OR user_id IN (SELECT id FROM users WHERE email LIKE ? OR name LIKE ?)",
searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern,
"id LIKE ? OR location_connection_ip LIKE ? OR host LIKE ? OR path LIKE ? OR CONCAT(host, path) LIKE ? OR user_id IN (SELECT id FROM users WHERE email LIKE ? OR name LIKE ?)",
searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern,
)
}
@@ -5132,9 +5132,10 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
}
if filter.Status != nil {
if *filter.Status == "success" {
switch *filter.Status {
case "success":
query = query.Where("status_code >= ? AND status_code < ?", 200, 400)
} else if *filter.Status == "failed" {
case "failed":
query = query.Where("status_code < ? OR status_code >= ?", 200, 400)
}
}
@@ -5154,7 +5155,7 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
return query
}
func (s *SqlStore) GetReverseProxyTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
@@ -5164,11 +5165,11 @@ func (s *SqlStore) GetReverseProxyTargetByTargetID(ctx context.Context, lockStre
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "reverse proxy target with ID %s not found", targetID)
return nil, status.Errorf(status.NotFound, "service target with ID %s not found", targetID)
}
log.WithContext(ctx).Errorf("failed to get reverse proxy target from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy target from store")
log.WithContext(ctx).Errorf("failed to get service target from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get service target from store")
}
return target, nil

View File

@@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -263,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&types.AccountOnboarding{},
&types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{},
}
for i := len(models) - 1; i >= 0; i-- {

View File

@@ -250,13 +250,13 @@ type Store interface {
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
CreateReverseProxy(ctx context.Context, service *reverseproxy.ReverseProxy) error
UpdateReverseProxy(ctx context.Context, service *reverseproxy.ReverseProxy) error
DeleteReverseProxy(ctx context.Context, accountID, serviceID string) error
GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.ReverseProxy, error)
GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error)
GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error)
CreateService(ctx context.Context, service *reverseproxy.Service) error
UpdateService(ctx context.Context, service *reverseproxy.Service) error
DeleteService(ctx context.Context, accountID, serviceID string) error
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
@@ -267,14 +267,24 @@ type Store interface {
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
GetReverseProxyTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
}
const (
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
postgresDsnEnv = "NB_STORE_ENGINE_POSTGRES_DSN"
postgresDsnEnvLegacy = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mysqlDsnEnv = "NB_STORE_ENGINE_MYSQL_DSN"
mysqlDsnEnvLegacy = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)
// lookupDSNEnv checks the NB_ env var first, then falls back to the legacy NETBIRD_ env var.
func lookupDSNEnv(nbKey, legacyKey string) (string, bool) {
if v, ok := os.LookupEnv(nbKey); ok {
return v, true
}
return os.LookupEnv(legacyKey)
}
var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine}
func getStoreEngineFromEnv() types.Engine {
@@ -559,7 +569,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine)
}
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
dsn, ok := os.LookupEnv(postgresDsnEnv)
dsn, ok := lookupDSNEnv(postgresDsnEnv, postgresDsnEnvLegacy)
if !ok || dsn == "" {
var err error
_, dsn, err = testutil.CreatePostgresTestContainer()
@@ -597,7 +607,7 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng
}
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
dsn, ok := os.LookupEnv(mysqlDsnEnv)
dsn, ok := lookupDSNEnv(mysqlDsnEnv, mysqlDsnEnvLegacy)
if !ok || dsn == "" {
var err error
_, dsn, err = testutil.CreateMysqlTestContainer()

View File

@@ -122,6 +122,7 @@ type defaultAppMetrics struct {
Meter metric2.Meter
listener net.Listener
ctx context.Context
externallyManaged bool
idpMetrics *IDPMetrics
httpMiddleware *HTTPMiddleware
grpcMetrics *GRPCMetrics
@@ -171,6 +172,9 @@ func (appMetrics *defaultAppMetrics) Close() error {
// Expose metrics on a given port and endpoint. If endpoint is empty a defaultEndpoint one will be used.
// Exposes metrics in the Prometheus format https://prometheus.io/
func (appMetrics *defaultAppMetrics) Expose(ctx context.Context, port int, endpoint string) error {
if appMetrics.externallyManaged {
return nil
}
if endpoint == "" {
endpoint = defaultEndpoint
}
@@ -252,3 +256,49 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
accountManagerMetrics: accountManagerMetrics,
}, nil
}
// NewAppMetricsWithMeter creates AppMetrics using an externally provided meter.
// The caller is responsible for exposing metrics via HTTP. Expose() and Close() are no-ops.
func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetrics, error) {
idpMetrics, err := NewIDPMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize IDP metrics: %w", err)
}
middleware, err := NewMetricsMiddleware(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize HTTP middleware metrics: %w", err)
}
grpcMetrics, err := NewGRPCMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize gRPC metrics: %w", err)
}
storeMetrics, err := NewStoreMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize store metrics: %w", err)
}
updateChannelMetrics, err := NewUpdateChannelMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize update channel metrics: %w", err)
}
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
}
return &defaultAppMetrics{
Meter: meter,
ctx: ctx,
externallyManaged: true,
idpMetrics: idpMetrics,
httpMiddleware: middleware,
grpcMetrics: grpcMetrics,
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
}, nil
}

View File

@@ -100,7 +100,7 @@ type Account struct {
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
ReverseProxies []*reverseproxy.ReverseProxy `gorm:"foreignKey:AccountID;references:id"`
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
@@ -377,7 +377,7 @@ func (a *Account) GetPeerNetworkMap(
// GetProxyConnectionResources returns ACL peers for the proxy-embedded peer based on exposed services.
// No firewall rules are generated here; the proxy peer is always a new on-demand client with a stateful
// firewall, so OUT rules are unnecessary. Inbound rules are handled on the target/router peer side.
func (a *Account) GetProxyConnectionResources(ctx context.Context, exposedServices map[string][]*reverseproxy.ReverseProxy) []*nbpeer.Peer {
func (a *Account) GetProxyConnectionResources(ctx context.Context, exposedServices map[string][]*reverseproxy.Service) []*nbpeer.Peer {
var aclPeers []*nbpeer.Peer
for _, peerServices := range exposedServices {
@@ -406,7 +406,7 @@ func (a *Account) GetProxyConnectionResources(ctx context.Context, exposedServic
// GetPeerProxyResources returns ACL peers and inbound firewall rules for a peer that is targeted by reverse proxy services.
// Only IN rules are generated; OUT rules are omitted since proxy peers are always new clients with stateful firewalls.
// Rules use PortRange only (not the legacy Port field) as this feature only targets current peer versions.
func (a *Account) GetPeerProxyResources(peerID string, services []*reverseproxy.ReverseProxy, proxyPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*FirewallRule) {
func (a *Account) GetPeerProxyResources(peerID string, services []*reverseproxy.Service, proxyPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*FirewallRule) {
var aclPeers []*nbpeer.Peer
var firewallRules []*FirewallRule
@@ -974,6 +974,11 @@ func (a *Account) Copy() *Account {
networkResources = append(networkResources, resource.Copy())
}
services := []*reverseproxy.Service{}
for _, service := range a.Services {
services = append(services, service.Copy())
}
return &Account{
Id: a.Id,
CreatedBy: a.CreatedBy,
@@ -995,6 +1000,7 @@ func (a *Account) Copy() *Account {
Networks: nets,
NetworkRouters: networkRouters,
NetworkResources: networkResources,
Services: services,
Onboarding: a.Onboarding,
NetworkMapCache: a.NetworkMapCache,
nmapInitOnce: a.nmapInitOnce,
@@ -1858,7 +1864,7 @@ func (a *Account) GetProxyPeers() map[string][]*nbpeer.Peer {
return proxyPeers
}
func (a *Account) GetPeerProxyRoutes(ctx context.Context, peer *nbpeer.Peer, proxies map[string][]*reverseproxy.ReverseProxy, resourcesMap map[string]*resourceTypes.NetworkResource, routers map[string]map[string]*routerTypes.NetworkRouter, proxyPeers []*nbpeer.Peer) ([]*route.Route, []*RouteFirewallRule, []*nbpeer.Peer) {
func (a *Account) GetPeerProxyRoutes(ctx context.Context, peer *nbpeer.Peer, proxies map[string][]*reverseproxy.Service, resourcesMap map[string]*resourceTypes.NetworkResource, routers map[string]map[string]*routerTypes.NetworkRouter, proxyPeers []*nbpeer.Peer) ([]*route.Route, []*RouteFirewallRule, []*nbpeer.Peer) {
sourceRanges := make([]string, 0, len(proxyPeers))
for _, proxyPeer := range proxyPeers {
sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, proxyPeer.IP))
@@ -1924,7 +1930,7 @@ func (a *Account) GetResourcesMap() map[string]*resourceTypes.NetworkResource {
}
func (a *Account) InjectProxyPolicies(ctx context.Context) {
if len(a.ReverseProxies) == 0 {
if len(a.Services) == 0 {
return
}
@@ -1933,7 +1939,7 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
return
}
for _, service := range a.ReverseProxies {
for _, service := range a.Services {
if !service.Enabled {
continue
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -70,7 +71,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, nil, account.GetActiveGroupUsers(), nil, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -115,7 +116,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
b.Run("old builder", func(b *testing.B) {
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, nil, account.GetActiveGroupUsers(), nil, nil)
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -177,7 +178,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, nil, account.GetActiveGroupUsers(), nil, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -240,7 +241,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, nil, account.GetActiveGroupUsers(), nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -317,7 +318,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, nil, account.GetActiveGroupUsers(), nil, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -402,7 +403,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, nil, account.GetActiveGroupUsers(), nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -458,7 +459,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, nil, account.GetActiveGroupUsers(), nil, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -537,7 +538,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, nil, account.GetActiveGroupUsers(), nil, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -597,7 +598,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, nil, account.GetActiveGroupUsers(), nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})

View File

@@ -36,7 +36,7 @@ type PlainProxyToken string
type ProxyAccessToken struct {
ID string `gorm:"primaryKey"`
Name string
HashedToken HashedProxyToken `gorm:"uniqueIndex"`
HashedToken HashedProxyToken `gorm:"type:varchar(255);uniqueIndex"`
// AccountID is nil for management-wide tokens, set for account-scoped tokens
AccountID *string `gorm:"index"`
ExpiresAt *time.Time

View File

@@ -1,409 +0,0 @@
package types
import (
"context"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func TestGetProxyConnectionResources_PeerTarget(t *testing.T) {
account := &Account{
Peers: map[string]*nbpeer.Peer{
"target-peer": {ID: "target-peer", IP: net.ParseIP("100.64.0.1")},
},
}
exposedServices := map[string][]*reverseproxy.ReverseProxy{
"target-peer": {
{
ID: "proxy-1",
Enabled: true,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypePeer,
TargetId: "target-peer",
Port: 8080,
Enabled: true,
},
},
},
},
}
aclPeers := account.GetProxyConnectionResources(context.Background(), exposedServices)
require.Len(t, aclPeers, 1)
assert.Equal(t, "target-peer", aclPeers[0].ID)
}
func TestGetProxyConnectionResources_DisabledService(t *testing.T) {
account := &Account{
Peers: map[string]*nbpeer.Peer{
"target-peer": {ID: "target-peer", IP: net.ParseIP("100.64.0.1")},
},
}
exposedServices := map[string][]*reverseproxy.ReverseProxy{
"target-peer": {
{
ID: "proxy-1",
Enabled: false,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypePeer,
TargetId: "target-peer",
Port: 8080,
Enabled: true,
},
},
},
},
}
aclPeers := account.GetProxyConnectionResources(context.Background(), exposedServices)
assert.Empty(t, aclPeers)
}
func TestGetProxyConnectionResources_ResourceTargetSkipped(t *testing.T) {
account := &Account{
Peers: map[string]*nbpeer.Peer{
"router-peer": {ID: "router-peer", IP: net.ParseIP("100.64.0.2")},
},
}
exposedServices := map[string][]*reverseproxy.ReverseProxy{
"router-peer": {
{
ID: "proxy-1",
Enabled: true,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypeResource,
TargetId: "resource-1",
Port: 443,
Enabled: true,
},
},
},
},
}
aclPeers := account.GetProxyConnectionResources(context.Background(), exposedServices)
assert.Empty(t, aclPeers, "resource targets should not add ACL peers via GetProxyConnectionResources")
}
func TestGetPeerProxyResources_PeerTarget(t *testing.T) {
proxyPeers := []*nbpeer.Peer{
{ID: "proxy-peer-1", IP: net.ParseIP("100.64.0.10")},
{ID: "proxy-peer-2", IP: net.ParseIP("100.64.0.11")},
}
services := []*reverseproxy.ReverseProxy{
{
ID: "proxy-1",
Enabled: true,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypePeer,
TargetId: "target-peer",
Port: 8080,
Enabled: true,
},
},
},
}
account := &Account{}
aclPeers, fwRules := account.GetPeerProxyResources("target-peer", services, proxyPeers)
require.Len(t, aclPeers, 2, "should include all proxy peers")
require.Len(t, fwRules, 2, "should have one IN rule per proxy peer")
for i, rule := range fwRules {
assert.Equal(t, "proxy-proxy-1", rule.PolicyID)
assert.Equal(t, proxyPeers[i].IP.String(), rule.PeerIP)
assert.Equal(t, FirewallRuleDirectionIN, rule.Direction)
assert.Equal(t, "allow", rule.Action)
assert.Equal(t, string(PolicyRuleProtocolTCP), rule.Protocol)
assert.Equal(t, uint16(8080), rule.PortRange.Start)
assert.Equal(t, uint16(8080), rule.PortRange.End)
}
}
func TestGetPeerProxyResources_PeerTargetMismatch(t *testing.T) {
proxyPeers := []*nbpeer.Peer{
{ID: "proxy-peer-1", IP: net.ParseIP("100.64.0.10")},
}
services := []*reverseproxy.ReverseProxy{
{
ID: "proxy-1",
Enabled: true,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypePeer,
TargetId: "other-peer",
Port: 8080,
Enabled: true,
},
},
},
}
account := &Account{}
aclPeers, fwRules := account.GetPeerProxyResources("target-peer", services, proxyPeers)
require.Len(t, aclPeers, 1, "should still add proxy peers to ACL")
assert.Empty(t, fwRules, "should not generate rules when target doesn't match this peer")
}
func TestGetPeerProxyResources_ResourceAccessLocal(t *testing.T) {
proxyPeers := []*nbpeer.Peer{
{ID: "proxy-peer-1", IP: net.ParseIP("100.64.0.10")},
}
services := []*reverseproxy.ReverseProxy{
{
ID: "proxy-1",
Enabled: true,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypeResource,
TargetId: "resource-1",
Port: 443,
Enabled: true,
AccessLocal: true,
},
},
},
}
account := &Account{}
aclPeers, fwRules := account.GetPeerProxyResources("router-peer", services, proxyPeers)
require.Len(t, aclPeers, 1, "should include proxy peers in ACL")
require.Len(t, fwRules, 1, "should generate IN rule for AccessLocal resource")
rule := fwRules[0]
assert.Equal(t, "proxy-proxy-1", rule.PolicyID)
assert.Equal(t, "100.64.0.10", rule.PeerIP)
assert.Equal(t, FirewallRuleDirectionIN, rule.Direction)
assert.Equal(t, uint16(443), rule.PortRange.Start)
}
func TestGetPeerProxyResources_ResourceWithoutAccessLocal(t *testing.T) {
proxyPeers := []*nbpeer.Peer{
{ID: "proxy-peer-1", IP: net.ParseIP("100.64.0.10")},
}
services := []*reverseproxy.ReverseProxy{
{
ID: "proxy-1",
Enabled: true,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypeResource,
TargetId: "resource-1",
Port: 443,
Enabled: true,
AccessLocal: false,
},
},
},
}
account := &Account{}
aclPeers, fwRules := account.GetPeerProxyResources("router-peer", services, proxyPeers)
require.Len(t, aclPeers, 1, "should still include proxy peers in ACL")
assert.Empty(t, fwRules, "should not generate peer rules when AccessLocal is false")
}
func TestGetPeerProxyResources_MixedTargets(t *testing.T) {
proxyPeers := []*nbpeer.Peer{
{ID: "proxy-peer-1", IP: net.ParseIP("100.64.0.10")},
}
services := []*reverseproxy.ReverseProxy{
{
ID: "proxy-1",
Enabled: true,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypePeer,
TargetId: "target-peer",
Port: 8080,
Enabled: true,
},
{
TargetType: reverseproxy.TargetTypeResource,
TargetId: "resource-1",
Port: 443,
Enabled: true,
AccessLocal: true,
},
{
TargetType: reverseproxy.TargetTypeResource,
TargetId: "resource-2",
Port: 8443,
Enabled: true,
AccessLocal: false,
},
},
},
}
account := &Account{}
aclPeers, fwRules := account.GetPeerProxyResources("target-peer", services, proxyPeers)
require.Len(t, aclPeers, 1)
require.Len(t, fwRules, 2, "should have rules for peer target + AccessLocal resource")
ports := []uint16{fwRules[0].PortRange.Start, fwRules[1].PortRange.Start}
assert.Contains(t, ports, uint16(8080), "should include peer target port")
assert.Contains(t, ports, uint16(443), "should include AccessLocal resource port")
}
func newProxyRoutesTestAccount() *Account {
return &Account{
Peers: map[string]*nbpeer.Peer{
"router-peer": {ID: "router-peer", Key: "router-key", IP: net.ParseIP("100.64.0.2")},
"proxy-peer": {ID: "proxy-peer", Key: "proxy-key", IP: net.ParseIP("100.64.0.10")},
},
}
}
func TestGetPeerProxyRoutes_ResourceWithoutAccessLocal(t *testing.T) {
account := newProxyRoutesTestAccount()
proxyPeers := []*nbpeer.Peer{account.Peers["proxy-peer"]}
resourcesMap := map[string]*resourceTypes.NetworkResource{
"resource-1": {
ID: "resource-1",
AccountID: "accountID",
NetworkID: "net-1",
Name: "web-service",
Type: resourceTypes.Host,
Prefix: netip.MustParsePrefix("192.168.1.100/32"),
Enabled: true,
},
}
routers := map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {
"router-peer": {ID: "router-1", NetworkID: "net-1", Peer: "router-peer", Masquerade: true, Metric: 100},
},
}
exposedServices := map[string][]*reverseproxy.ReverseProxy{
"router-peer": {
{
ID: "proxy-1",
Enabled: true,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypeResource,
TargetId: "resource-1",
Port: 443,
Enabled: true,
AccessLocal: false,
},
},
},
},
}
routes, routeFwRules, aclPeers := account.GetPeerProxyRoutes(context.Background(), account.Peers["proxy-peer"], exposedServices, resourcesMap, routers, proxyPeers)
require.NotEmpty(t, routes, "should generate routes for non-AccessLocal resource")
require.NotEmpty(t, routeFwRules, "should generate route firewall rules for non-AccessLocal resource")
require.NotEmpty(t, aclPeers, "should include router peer in ACL")
assert.Equal(t, uint16(443), routeFwRules[0].PortRange.Start)
assert.Equal(t, "192.168.1.100/32", routeFwRules[0].Destination)
}
func TestGetPeerProxyRoutes_ResourceWithAccessLocal(t *testing.T) {
account := newProxyRoutesTestAccount()
proxyPeers := []*nbpeer.Peer{account.Peers["proxy-peer"]}
resourcesMap := map[string]*resourceTypes.NetworkResource{
"resource-1": {
ID: "resource-1",
AccountID: "accountID",
NetworkID: "net-1",
Name: "local-service",
Type: resourceTypes.Host,
Prefix: netip.MustParsePrefix("192.168.1.100/32"),
Enabled: true,
},
}
routers := map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {
"router-peer": {ID: "router-1", NetworkID: "net-1", Peer: "router-peer", Masquerade: true, Metric: 100},
},
}
exposedServices := map[string][]*reverseproxy.ReverseProxy{
"router-peer": {
{
ID: "proxy-1",
Enabled: true,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypeResource,
TargetId: "resource-1",
Port: 443,
Enabled: true,
AccessLocal: true,
},
},
},
},
}
routes, routeFwRules, aclPeers := account.GetPeerProxyRoutes(context.Background(), account.Peers["proxy-peer"], exposedServices, resourcesMap, routers, proxyPeers)
require.NotEmpty(t, routes, "should generate routes for AccessLocal resource")
require.NotEmpty(t, routeFwRules, "should generate route firewall rules for AccessLocal resource")
require.NotEmpty(t, aclPeers, "should include router peer in ACL for AccessLocal resource")
assert.Equal(t, uint16(443), routeFwRules[0].PortRange.Start)
assert.Equal(t, "192.168.1.100/32", routeFwRules[0].Destination)
}
func TestGetPeerProxyRoutes_PeerTargetSkipped(t *testing.T) {
account := newProxyRoutesTestAccount()
proxyPeers := []*nbpeer.Peer{account.Peers["proxy-peer"]}
exposedServices := map[string][]*reverseproxy.ReverseProxy{
"router-peer": {
{
ID: "proxy-1",
Enabled: true,
Targets: []reverseproxy.Target{
{
TargetType: reverseproxy.TargetTypePeer,
TargetId: "target-peer",
Port: 8080,
Enabled: true,
},
},
},
},
}
routes, routeFwRules, aclPeers := account.GetPeerProxyRoutes(context.Background(), account.Peers["proxy-peer"], exposedServices, nil, nil, proxyPeers)
assert.Empty(t, routes, "should NOT generate routes for peer targets")
assert.Empty(t, routeFwRules, "should NOT generate route firewall rules for peer targets")
assert.Empty(t, aclPeers)
}

View File

@@ -16,7 +16,7 @@ FROM gcr.io/distroless/base:debug
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
COPY --from=builder /tmp/passwd /etc/passwd
COPY --from=builder /tmp/group /etc/group
COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 /tmp/certs /certs
USER netbird:netbird
ENV HOME=/var/lib/netbird

View File

@@ -70,7 +70,7 @@ The following deployment configuration is available:
| `-mgmt` | `NB_PROXY_MANAGEMENT_ADDRESS` | The address of the management server for the proxy to get configuration from. | `"https://api.netbird.io:443"` |
| `-addr` | `NB_PROXY_ADDRESS` | The address that the reverse proxy will listen on. | `":443` |
| `-url` | `NB_PROXY_URL` | The URL that the proxy will be reached at (where endpoints will be CNAMEd to). If unset, this will fall back to the proxy address. | `"proxy.netbird.io"` |
| `-cert-dir` | `NB_PROXY_CERTIFICATE_DIRECTORY` | The location that certficates are stored in. | `"./certs"` |
| `-cert-dir` | `NB_PROXY_CERTIFICATE_DIRECTORY` | The location that certificates are stored in. | `"./certs"` |
| `-acme-certs` | `NB_PROXY_ACME_CERTIFICATES` | Whether to use ACME to generate certificates. | `false` |
| `-acme-addr` | `NB_PROXY_ACME_ADDRESS` | The HTTP address the proxy will listen on to respond to HTTP-01 ACME challenges | `":80"` |
| `-acme-dir` | `NB_PROXY_ACME_DIRECTORY` | The directory URL of the ACME server to be used | `"https://acme-v02.api.letsencrypt.org/directory"` |

View File

@@ -9,6 +9,7 @@ import (
"strings"
"syscall"
"github.com/netbirdio/netbird/shared/management/domain"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/acme"
@@ -21,6 +22,8 @@ import (
const DefaultManagementURL = "https://api.netbird.io:443"
// envProxyToken is the environment variable name for the proxy access token.
//
//nolint:gosec
const envProxyToken = "NB_PROXY_TOKEN"
var (
@@ -34,11 +37,12 @@ var (
debugLogs bool
mgmtAddr string
addr string
proxyURL string
proxyDomain string
certDir string
acmeCerts bool
acmeAddr string
acmeDir string
acmeCerts bool
acmeAddr string
acmeDir string
acmeChallengeType string
debugEndpoint bool
debugEndpointAddr string
healthAddr string
@@ -67,11 +71,12 @@ func init() {
rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to")
rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on")
rootCmd.Flags().StringVar(&proxyURL, "url", envStringOrDefault("NB_PROXY_URL", ""), "The URL at which this proxy will be reached")
rootCmd.Flags().StringVar(&proxyDomain, "domain", envStringOrDefault("NB_PROXY_DOMAIN", ""), "The Domain at which this proxy will be reached. e.g., netbird.example.com")
rootCmd.Flags().StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store certificates")
rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges")
rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges")
rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates automatically")
rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges (only used when acme-challenge-type is http-01)")
rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory")
rootCmd.Flags().StringVar(&acmeChallengeType, "acme-challenge-type", envStringOrDefault("NB_PROXY_ACME_CHALLENGE_TYPE", "tls-alpn-01"), "ACME challenge type: tls-alpn-01 (default, port 443 only) or http-01 (requires port 80)")
rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
rootCmd.Flags().StringVar(&healthAddr, "health-addr", envStringOrDefault("NB_PROXY_HEALTH_ADDRESS", "localhost:8080"), "Address for the health probe endpoint (liveness/readiness/startup)")
@@ -126,6 +131,11 @@ func runServer(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid --forwarded-proto value %q: must be auto, http, or https", forwardedProto)
}
_, err := domain.ValidateDomains([]string{proxyDomain})
if err != nil {
return fmt.Errorf("invalid domain value %q: %w", proxyDomain, err)
}
parsedTrustedProxies, err := proxy.ParseTrustedProxies(trustedProxies)
if err != nil {
return fmt.Errorf("invalid --trusted-proxies: %w", err)
@@ -135,7 +145,7 @@ func runServer(cmd *cobra.Command, args []string) error {
Logger: logger,
Version: Version,
ManagementAddress: mgmtAddr,
ProxyURL: proxyURL,
ProxyURL: proxyDomain,
ProxyToken: proxyToken,
CertificateDirectory: certDir,
CertificateFile: certFile,
@@ -143,6 +153,7 @@ func runServer(cmd *cobra.Command, args []string) error {
GenerateACMECertificates: acmeCerts,
ACMEChallengeAddress: acmeAddr,
ACMEDirectory: acmeDir,
ACMEChallengeType: acmeChallengeType,
DebugEndpointEnabled: debugEndpoint,
DebugEndpointAddress: debugEndpointAddr,
HealthAddress: healthAddr,
@@ -160,7 +171,8 @@ func runServer(cmd *cobra.Command, args []string) error {
defer stop()
if err := srv.ListenAndServe(ctx, addr); err != nil {
log.Fatal(err)
log.Error(err)
return err
}
return nil
}

View File

@@ -0,0 +1,94 @@
package proxy
import (
"context"
"io"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/shared/management/proto"
)
type mockMappingStream struct {
grpc.ClientStream
messages []*proto.GetMappingUpdateResponse
idx int
}
func (m *mockMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
if m.idx >= len(m.messages) {
return nil, io.EOF
}
msg := m.messages[m.idx]
m.idx++
return msg, nil
}
func (m *mockMappingStream) Header() (metadata.MD, error) {
return nil, nil //nolint:nilnil
}
func (m *mockMappingStream) Trailer() metadata.MD { return nil }
func (m *mockMappingStream) CloseSend() error { return nil }
func (m *mockMappingStream) Context() context.Context { return context.Background() }
func (m *mockMappingStream) SendMsg(any) error { return nil }
func (m *mockMappingStream) RecvMsg(any) error { return nil }
func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{InitialSyncComplete: true},
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "initial sync should be marked done when flag is set")
}
func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // no sync flag
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.False(t, syncDone, "initial sync should not be marked done without flag")
}
func TestHandleMappingStream_NilHealthChecker(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{InitialSyncComplete: true},
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "sync done flag should be set even without health checker")
}

View File

@@ -3,11 +3,13 @@ package accesslog
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -62,7 +64,12 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
// allow for resolving that on the server.
now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary.
go func() {
if _, err := l.client.SendAccessLog(context.Background(), &proto.SendAccessLogRequest{
logCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if entry.AuthMechanism != auth.MethodOIDC.String() {
entry.UserId = ""
}
if _, err := l.client.SendAccessLog(logCtx, &proto.SendAccessLogRequest{
Log: &proto.AccessLog{
LogId: entry.ID,
AccountId: entry.AccountID,

View File

@@ -23,7 +23,7 @@ import (
var oidSCTList = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 2}
type certificateNotifier interface {
NotifyCertificateIssued(ctx context.Context, accountID, reverseProxyID, domain string) error
NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error
}
type domainState int
@@ -35,10 +35,10 @@ const (
)
type domainInfo struct {
accountID string
reverseProxyID string
state domainState
err string
accountID string
serviceID string
state domainState
err string
}
// Manager wraps autocert.Manager with domain tracking and cross-replica
@@ -95,12 +95,12 @@ func (mgr *Manager) hostPolicy(_ context.Context, host string) error {
}
// AddDomain registers a domain for ACME certificate prefetching.
func (mgr *Manager) AddDomain(d domain.Domain, accountID, reverseProxyID string) {
func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) {
mgr.mu.Lock()
mgr.domains[d] = &domainInfo{
accountID: accountID,
reverseProxyID: reverseProxyID,
state: domainPending,
accountID: accountID,
serviceID: serviceID,
state: domainPending,
}
mgr.mu.Unlock()
@@ -164,7 +164,7 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) {
mgr.mu.RUnlock()
if info != nil && mgr.certNotifier != nil {
if err := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.reverseProxyID, name); err != nil {
if err := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.serviceID, name); err != nil {
mgr.logger.Warnf("notify certificate ready for domain %q: %v", name, err)
}
}

View File

@@ -1,40 +0,0 @@
package auth
import (
"context"
"github.com/netbirdio/netbird/proxy/auth"
)
type requestContextKey string
const (
authMethodKey requestContextKey = "authMethod"
authUserKey requestContextKey = "authUser"
)
func withAuthMethod(ctx context.Context, method auth.Method) context.Context {
return context.WithValue(ctx, authMethodKey, method)
}
func MethodFromContext(ctx context.Context) auth.Method {
v := ctx.Value(authMethodKey)
method, ok := v.(auth.Method)
if !ok {
return ""
}
return method
}
func withAuthUser(ctx context.Context, userId string) context.Context {
return context.WithValue(ctx, authUserKey, userId)
}
func UserFromContext(ctx context.Context) string {
v := ctx.Value(authUserKey)
userId, ok := v.(string)
if !ok {
return ""
}
return userId
}

View File

@@ -30,15 +30,14 @@ type SessionValidator interface {
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
}
// Scheme defines an authentication mechanism for a domain.
type Scheme interface {
Type() auth.Method
// Authenticate should check the passed request and determine whether
// it represents an authenticated user request. If it does not, then
// an empty string should indicate an unauthenticated request which
// will be rejected; optionally, it can also return any data that should
// be included in a UI template when prompting the user to authenticate.
// If the request is authenticated, then a session token should be returned.
Authenticate(*http.Request) (token string, promptData string)
// Authenticate checks the request and determines whether it represents
// an authenticated user. An empty token indicates an unauthenticated
// request; optionally, promptData may be returned for the login UI.
// An error indicates an infrastructure failure (e.g. gRPC unavailable).
Authenticate(*http.Request) (token string, promptData string, err error)
}
type DomainConfig struct {
@@ -141,7 +140,15 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
methods := make(map[string]string)
var attemptedMethod string
for _, scheme := range config.Schemes {
token, promptData := scheme.Authenticate(r)
token, promptData, err := scheme.Authenticate(r)
if err != nil {
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
return
}
// Track if credentials were submitted but auth failed
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {

View File

@@ -32,16 +32,16 @@ type stubScheme struct {
method auth.Method
token string
promptID string
authFn func(*http.Request) (string, string)
authFn func(*http.Request) (string, string, error)
}
func (s *stubScheme) Type() auth.Method { return s.method }
func (s *stubScheme) Authenticate(r *http.Request) (string, string) {
func (s *stubScheme) Authenticate(r *http.Request) (string, string, error) {
if s.authFn != nil {
return s.authFn(r)
}
return s.token, s.promptID
return s.token, s.promptID, nil
}
func newPassthroughHandler() http.Handler {
@@ -344,11 +344,11 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
scheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(r *http.Request) (string, string) {
authFn: func(r *http.Request) (string, string, error) {
if r.FormValue("pin") == "111111" {
return token, ""
return token, "", nil
}
return "", "pin"
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
@@ -391,8 +391,8 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
scheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(_ *http.Request) (string, string) {
return "", "pin"
authFn: func(_ *http.Request) (string, string, error) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
@@ -418,17 +418,17 @@ func TestProtect_MultipleSchemes(t *testing.T) {
// First scheme (PIN) always fails, second scheme (password) succeeds.
pinScheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(_ *http.Request) (string, string) {
return "", "pin"
authFn: func(_ *http.Request) (string, string, error) {
return "", "pin", nil
},
}
passwordScheme := &stubScheme{
method: auth.MethodPassword,
authFn: func(r *http.Request) (string, string) {
authFn: func(r *http.Request) (string, string, error) {
if r.FormValue("password") == "secret" {
return token, ""
return token, "", nil
}
return "", "password"
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", ""))
@@ -457,8 +457,8 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
// Return a garbage token that won't validate.
scheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(_ *http.Request) (string, string) {
return "invalid-jwt-token", ""
authFn: func(_ *http.Request) (string, string, error) {
return "invalid-jwt-token", "", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
@@ -517,8 +517,8 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
// Scheme that always fails authentication (returns empty token)
scheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(_ *http.Request) (string, string) {
return "", "pin"
authFn: func(_ *http.Request) (string, string, error) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
@@ -544,8 +544,8 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
scheme := &stubScheme{
method: auth.MethodPassword,
authFn: func(_ *http.Request) (string, string) {
return "", "password"
authFn: func(_ *http.Request) (string, string, error) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
@@ -571,8 +571,8 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
scheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(_ *http.Request) (string, string) {
return "", "pin"
authFn: func(_ *http.Request) (string, string, error) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))

View File

@@ -2,6 +2,7 @@ package auth
import (
"context"
"fmt"
"net/http"
"net/url"
@@ -36,12 +37,13 @@ func (OIDC) Type() auth.Method {
return auth.MethodOIDC
}
func (o OIDC) Authenticate(r *http.Request) (string, string) {
// Authenticate checks for an OIDC session token or obtains the OIDC redirect URL.
func (o OIDC) Authenticate(r *http.Request) (string, string, error) {
// Check for the session_token query param (from OIDC redirects).
// The management server passes the token in the URL because it cannot set
// cookies for the proxy's domain (cookies are domain-scoped per RFC 6265).
if token := r.URL.Query().Get("session_token"); token != "" {
return token, ""
return token, "", nil
}
redirectURL := &url.URL{
@@ -56,9 +58,8 @@ func (o OIDC) Authenticate(r *http.Request) (string, string) {
RedirectUrl: redirectURL.String(),
})
if err != nil {
// TODO: log
return "", ""
return "", "", fmt.Errorf("get OIDC URL: %w", err)
}
return "", res.GetUrl()
return "", res.GetUrl(), nil
}

View File

@@ -1,6 +1,7 @@
package auth
import (
"fmt"
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
@@ -31,12 +32,12 @@ func (Password) Type() auth.Method {
// If authentication fails, the required HTTP form ID is returned
// so that it can be injected into a request from the UI so that
// authentication may be successful.
func (p Password) Authenticate(r *http.Request) (string, string) {
func (p Password) Authenticate(r *http.Request) (string, string, error) {
password := r.FormValue(passwordFormId)
if password == "" {
// This cannot be authenticated, so not worth wasting time sending the request.
return "", passwordFormId
// No password submitted; return the form ID so the UI can prompt the user.
return "", passwordFormId, nil
}
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
@@ -49,13 +50,12 @@ func (p Password) Authenticate(r *http.Request) (string, string) {
},
})
if err != nil {
// TODO: log error here
return "", passwordFormId
return "", "", fmt.Errorf("authenticate password: %w", err)
}
if res.GetSuccess() {
return res.GetSessionToken(), ""
return res.GetSessionToken(), "", nil
}
return "", passwordFormId
return "", passwordFormId, nil
}

View File

@@ -1,6 +1,7 @@
package auth
import (
"fmt"
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
@@ -31,12 +32,12 @@ func (Pin) Type() auth.Method {
// If authentication fails, the required HTTP form ID is returned
// so that it can be injected into a request from the UI so that
// authentication may be successful.
func (p Pin) Authenticate(r *http.Request) (string, string) {
func (p Pin) Authenticate(r *http.Request) (string, string, error) {
pin := r.FormValue(pinFormId)
if pin == "" {
// This cannot be authenticated, so not worth wasting time sending the request.
return "", pinFormId
// No PIN submitted; return the form ID so the UI can prompt the user.
return "", pinFormId, nil
}
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
@@ -49,13 +50,12 @@ func (p Pin) Authenticate(r *http.Request) (string, string) {
},
})
if err != nil {
// TODO: log error here
return "", pinFormId
return "", "", fmt.Errorf("authenticate pin: %w", err)
}
if res.GetSuccess() {
return res.GetSessionToken(), ""
return res.GetSessionToken(), "", nil
}
return "", pinFormId
return "", pinFormId, nil
}

View File

@@ -88,8 +88,9 @@ func (c *Client) printHealth(data map[string]any) {
return
}
_, _ = fmt.Fprintf(c.out, "\n%-38s %-9s %-7s %-8s %s\n", "ACCOUNT ID", "HEALTHY", "MGMT", "SIGNAL", "RELAYS")
_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 80))
_, _ = fmt.Fprintf(c.out, "\n%-38s %-9s %-7s %-8s %-8s %-16s %s\n",
"ACCOUNT ID", "HEALTHY", "MGMT", "SIGNAL", "RELAYS", "PEERS (P2P/RLY)", "DEGRADED")
_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))
for accountID, v := range clients {
ch, ok := v.(map[string]any)
@@ -105,7 +106,15 @@ func (c *Client) printHealth(data map[string]any) {
relaysTotal, _ := ch["relays_total"].(float64)
relays := fmt.Sprintf("%d/%d", int(relaysConn), int(relaysTotal))
_, _ = fmt.Fprintf(c.out, "%-38s %-9s %-7s %-8s %s", accountID, healthy, mgmt, signal, relays)
peersConnected, _ := ch["peers_connected"].(float64)
peersTotal, _ := ch["peers_total"].(float64)
peersP2P, _ := ch["peers_p2p"].(float64)
peersRelayed, _ := ch["peers_relayed"].(float64)
peersDegraded, _ := ch["peers_degraded"].(float64)
peers := fmt.Sprintf("%d/%d (%d/%d)", int(peersConnected), int(peersTotal), int(peersP2P), int(peersRelayed))
degraded := fmt.Sprintf("%d", int(peersDegraded))
_, _ = fmt.Fprintf(c.out, "%-38s %-9s %-7s %-8s %-8s %-16s %s", accountID, healthy, mgmt, signal, relays, peers, degraded)
if errMsg, ok := ch["error"].(string); ok && errMsg != "" {
_, _ = fmt.Fprintf(c.out, " (%s)", errMsg)
}

View File

@@ -648,7 +648,8 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON
allHealthy, clientHealth := h.health.CheckClientsConnected(r.Context())
status := "ok"
if !ready || !allHealthy {
// No clients is not a health issue; only degrade when actual clients are unhealthy
if !ready || (!allHealthy && len(clientHealth) > 0) {
status = "degraded"
}

View File

@@ -1,6 +1,6 @@
{{define "clientDetail"}}
<!DOCTYPE html>
<html>
<html lang="en">
<head>
<title>Client {{.AccountID}}</title>
<style>{{template "style"}}</style>

Some files were not shown because too many files have changed in this diff Show More