mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Compare commits
41 Commits
go-dns-for
...
zitadel-id
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a369357a85 | ||
|
|
537151e0f3 | ||
|
|
a9c28ef723 | ||
|
|
c29bb1a289 | ||
|
|
447cd287f5 | ||
|
|
5748bdd64e | ||
|
|
08f31fbcb3 | ||
|
|
932c02eaab | ||
|
|
abcbde26f9 | ||
|
|
90e3b8009f | ||
|
|
94d34dc0c5 | ||
|
|
44851e06fb | ||
|
|
3f4f825ec1 | ||
|
|
f538e6e9ae | ||
|
|
cb6b086164 | ||
|
|
71b6855e09 | ||
|
|
9bdc4908fb | ||
|
|
031ab11178 | ||
|
|
d2e48d4f5e | ||
|
|
27dd97c9c4 | ||
|
|
e87b4ace11 | ||
|
|
a232cf614c | ||
|
|
a293f760af | ||
|
|
10e9cf8c62 | ||
|
|
7193bd2da7 | ||
|
|
52948ccd61 | ||
|
|
4b77359042 | ||
|
|
387d43bcc1 | ||
|
|
e47d815dd2 | ||
|
|
cb83b7c0d3 | ||
|
|
ddcd182859 | ||
|
|
aca0398105 | ||
|
|
02200d790b | ||
|
|
f31bba87b4 | ||
|
|
7285fef0f0 | ||
|
|
20973063d8 | ||
|
|
ba2e9b6d88 | ||
|
|
131d7a3694 | ||
|
|
290fe2d8b9 | ||
|
|
7fb1a2fe31 | ||
|
|
32146e576d |
11
.githooks/pre-push
Executable file
11
.githooks/pre-push
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "Running pre-push hook..."
|
||||
if ! make lint; then
|
||||
echo ""
|
||||
echo "Hint: To push without verification, run:"
|
||||
echo " git push --no-verify"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "All checks passed!"
|
||||
7
.github/workflows/golang-test-darwin.yml
vendored
7
.github/workflows/golang-test-darwin.yml
vendored
@@ -15,13 +15,14 @@ jobs:
|
||||
name: "Client / Unit"
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
|
||||
2
.github/workflows/golang-test-freebsd.yml
vendored
2
.github/workflows/golang-test-freebsd.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
release: "14.2"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
71
.github/workflows/golang-test-linux.yml
vendored
71
.github/workflows/golang-test-linux.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Get Go environment
|
||||
@@ -106,15 +106,15 @@ jobs:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -151,15 +151,15 @@ jobs:
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
id: go-env
|
||||
run: |
|
||||
@@ -200,7 +200,7 @@ jobs:
|
||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||
-e CONTAINER=${CONTAINER} \
|
||||
golang:1.23-alpine \
|
||||
golang:1.24-alpine \
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
@@ -220,15 +220,15 @@ jobs:
|
||||
raceFlag: "-race"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
@@ -270,15 +270,15 @@ jobs:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
@@ -321,15 +321,15 @@ jobs:
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -408,15 +408,16 @@ jobs:
|
||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||
-p 9090:9090 \
|
||||
prom/prometheus
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -497,15 +498,15 @@ jobs:
|
||||
-p 9090:9090 \
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -561,15 +562,15 @@ jobs:
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
|
||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
uses: actions/setup-go@v5
|
||||
id: go
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Get Go environment
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
with:
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
- name: Setup NDK
|
||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build android netbird lib
|
||||
@@ -56,9 +56,9 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build iOS netbird lib
|
||||
|
||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -20,7 +20,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -136,7 +136,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -200,7 +200,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
|
||||
@@ -67,10 +67,13 @@ jobs:
|
||||
- name: Install curl
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -80,9 +83,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup MySQL privileges
|
||||
if: matrix.store == 'mysql'
|
||||
run: |
|
||||
|
||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||
env:
|
||||
@@ -60,8 +60,8 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 52428800 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
||||
if [ ${SIZE} -gt 57671680 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -136,6 +136,14 @@ checked out and set up:
|
||||
go mod tidy
|
||||
```
|
||||
|
||||
6. Configure Git hooks for automatic linting:
|
||||
|
||||
```bash
|
||||
make setup-hooks
|
||||
```
|
||||
|
||||
This will configure Git to run linting automatically before each push, helping catch issues early.
|
||||
|
||||
### Dev Container Support
|
||||
|
||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
||||
|
||||
27
Makefile
Normal file
27
Makefile
Normal file
@@ -0,0 +1,27 @@
|
||||
.PHONY: lint lint-all lint-install setup-hooks
|
||||
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||
|
||||
# Install golangci-lint locally if needed
|
||||
$(GOLANGCI_LINT):
|
||||
@echo "Installing golangci-lint..."
|
||||
@mkdir -p ./bin
|
||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# Lint only changed files (fast, for pre-push)
|
||||
lint: $(GOLANGCI_LINT)
|
||||
@echo "Running lint on changed files..."
|
||||
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
|
||||
|
||||
# Lint entire codebase (slow, matches CI)
|
||||
lint-all: $(GOLANGCI_LINT)
|
||||
@echo "Running lint on all files..."
|
||||
@$(GOLANGCI_LINT) run --timeout=12m
|
||||
|
||||
# Just install the linter
|
||||
lint-install: $(GOLANGCI_LINT)
|
||||
|
||||
# Setup git hooks for all developers
|
||||
setup-hooks:
|
||||
@git config core.hooksPath .githooks
|
||||
@chmod +x .githooks/pre-push
|
||||
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||
@@ -4,10 +4,13 @@ package android
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -16,10 +19,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -62,17 +68,18 @@ type Client struct {
|
||||
deviceName string
|
||||
uiVersion string
|
||||
networkChangeListener listener.NetworkChangeListener
|
||||
stateFile string
|
||||
|
||||
connectClient *internal.ConnectClient
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
execWorkaround(androidSDKVersion)
|
||||
|
||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
cfgFile: platformFiles.ConfigurationFilePath(),
|
||||
deviceName: deviceName,
|
||||
uiVersion: uiVersion,
|
||||
tunAdapter: tunAdapter,
|
||||
@@ -80,11 +87,12 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
|
||||
recorder: peer.NewRecorder(""),
|
||||
ctxCancelLock: &sync.Mutex{},
|
||||
networkChangeListener: networkChangeListener,
|
||||
stateFile: platformFiles.StateFilePath(),
|
||||
}
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
@@ -107,7 +115,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
c.ctxCancelLock.Unlock()
|
||||
|
||||
auth := NewAuthWithConfig(ctx, cfg)
|
||||
err = auth.login(urlOpener)
|
||||
err = auth.login(urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -115,7 +123,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
@@ -142,7 +150,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -156,6 +164,19 @@ func (c *Client) Stop() {
|
||||
c.ctxCancel()
|
||||
}
|
||||
|
||||
func (c *Client) RenewTun(fd int) error {
|
||||
if c.connectClient == nil {
|
||||
return fmt.Errorf("engine not running")
|
||||
}
|
||||
|
||||
e := c.connectClient.Engine()
|
||||
if e == nil {
|
||||
return fmt.Errorf("engine not initialized")
|
||||
}
|
||||
|
||||
return e.RenewTun(fd)
|
||||
}
|
||||
|
||||
// SetTraceLogLevel configure the logger to trace level
|
||||
func (c *Client) SetTraceLogLevel() {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
@@ -177,6 +198,7 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
p.IP,
|
||||
p.FQDN,
|
||||
p.ConnStatus.String(),
|
||||
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||
}
|
||||
peerInfos[n] = pi
|
||||
}
|
||||
@@ -201,31 +223,43 @@ func (c *Client) Networks() *NetworkArray {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeSelector := routeManager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
log.Error("could not get route selector")
|
||||
return nil
|
||||
}
|
||||
|
||||
networkArray := &NetworkArray{
|
||||
items: make([]Network, 0),
|
||||
}
|
||||
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||
if len(routes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r := routes[0]
|
||||
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
||||
netStr := r.Network.String()
|
||||
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||
continue
|
||||
}
|
||||
network := Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: peer.FQDN,
|
||||
Status: peer.ConnStatus.String(),
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: routePeer.FQDN,
|
||||
Status: routePeer.ConnStatus.String(),
|
||||
IsSelected: routeSelector.IsSelected(id),
|
||||
Domains: domains,
|
||||
}
|
||||
networkArray.Add(network)
|
||||
}
|
||||
@@ -253,6 +287,69 @@ func (c *Client) RemoveConnectionListener() {
|
||||
c.recorder.RemoveConnectionListener()
|
||||
}
|
||||
|
||||
func (c *Client) toggleRoute(command routeCommand) error {
|
||||
return command.toggleRoute()
|
||||
}
|
||||
|
||||
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||
client := c.connectClient
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := client.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("engine is not running")
|
||||
}
|
||||
|
||||
manager := engine.GetRouteManager()
|
||||
if manager == nil {
|
||||
return nil, fmt.Errorf("could not get route manager")
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func (c *Client) SelectRoute(route string) error {
|
||||
manager, err := c.getRouteManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.toggleRoute(selectRouteCommand{route: route, manager: manager})
|
||||
}
|
||||
|
||||
func (c *Client) DeselectRoute(route string) error {
|
||||
manager, err := c.getRouteManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.toggleRoute(deselectRouteCommand{route: route, manager: manager})
|
||||
}
|
||||
|
||||
// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain
|
||||
// with its resolved IP addresses from the provided resolvedDomains map.
|
||||
func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains {
|
||||
domains := NetworkDomains{}
|
||||
|
||||
for _, d := range route.Domains {
|
||||
networkDomain := NetworkDomain{
|
||||
Address: d.SafeString(),
|
||||
}
|
||||
|
||||
if info, exists := resolvedDomains[d]; exists {
|
||||
for _, prefix := range info.Prefixes {
|
||||
networkDomain.addResolvedIP(prefix.Addr().String())
|
||||
}
|
||||
}
|
||||
|
||||
domains.Add(&networkDomain)
|
||||
}
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
func exportEnvList(list *EnvList) {
|
||||
if list == nil {
|
||||
return
|
||||
|
||||
@@ -32,7 +32,7 @@ type ErrListener interface {
|
||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||
// the backend want to show an url for the user
|
||||
type URLOpener interface {
|
||||
Open(string)
|
||||
Open(url string, userCode string)
|
||||
OnLoginSuccess()
|
||||
}
|
||||
|
||||
@@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
||||
}
|
||||
|
||||
// Login try register the client on the server
|
||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) {
|
||||
go func() {
|
||||
err := a.login(urlOpener)
|
||||
err := a.login(urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
resultListener.OnError(err)
|
||||
} else {
|
||||
@@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener) error {
|
||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||
var needsLogin bool
|
||||
|
||||
// check if we need to generate JWT token
|
||||
@@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
|
||||
56
client/android/network_domains.go
Normal file
56
client/android/network_domains.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ResolvedIPs struct {
|
||||
resolvedIPs []string
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Add(ipAddress string) {
|
||||
r.resolvedIPs = append(r.resolvedIPs, ipAddress)
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Get(i int) (string, error) {
|
||||
if i < 0 || i >= len(r.resolvedIPs) {
|
||||
return "", fmt.Errorf("%d is out of range", i)
|
||||
}
|
||||
return r.resolvedIPs[i], nil
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Size() int {
|
||||
return len(r.resolvedIPs)
|
||||
}
|
||||
|
||||
type NetworkDomain struct {
|
||||
Address string
|
||||
resolvedIPs ResolvedIPs
|
||||
}
|
||||
|
||||
func (d *NetworkDomain) addResolvedIP(resolvedIP string) {
|
||||
d.resolvedIPs.Add(resolvedIP)
|
||||
}
|
||||
|
||||
func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs {
|
||||
return &d.resolvedIPs
|
||||
}
|
||||
|
||||
type NetworkDomains struct {
|
||||
domains []*NetworkDomain
|
||||
}
|
||||
|
||||
func (n *NetworkDomains) Add(domain *NetworkDomain) {
|
||||
n.domains = append(n.domains, domain)
|
||||
}
|
||||
|
||||
func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) {
|
||||
if i < 0 || i >= len(n.domains) {
|
||||
return nil, fmt.Errorf("%d is out of range", i)
|
||||
}
|
||||
return n.domains[i], nil
|
||||
}
|
||||
|
||||
func (n *NetworkDomains) Size() int {
|
||||
return len(n.domains)
|
||||
}
|
||||
@@ -3,10 +3,16 @@
|
||||
package android
|
||||
|
||||
type Network struct {
|
||||
Name string
|
||||
Network string
|
||||
Peer string
|
||||
Status string
|
||||
Name string
|
||||
Network string
|
||||
Peer string
|
||||
Status string
|
||||
IsSelected bool
|
||||
Domains NetworkDomains
|
||||
}
|
||||
|
||||
func (n Network) GetNetworkDomains() *NetworkDomains {
|
||||
return &n.Domains
|
||||
}
|
||||
|
||||
type NetworkArray struct {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||
@@ -5,6 +7,11 @@ type PeerInfo struct {
|
||||
IP string
|
||||
FQDN string
|
||||
ConnStatus string // Todo replace to enum
|
||||
Routes PeerRoutes
|
||||
}
|
||||
|
||||
func (p *PeerInfo) GetPeerRoutes() *PeerRoutes {
|
||||
return &p.Routes
|
||||
}
|
||||
|
||||
// PeerInfoArray is a wrapper of []PeerInfo
|
||||
|
||||
20
client/android/peer_routes.go
Normal file
20
client/android/peer_routes.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import "fmt"
|
||||
|
||||
type PeerRoutes struct {
|
||||
routes []string
|
||||
}
|
||||
|
||||
func (p *PeerRoutes) Get(i int) (string, error) {
|
||||
if i < 0 || i >= len(p.routes) {
|
||||
return "", fmt.Errorf("%d is out of range", i)
|
||||
}
|
||||
return p.routes[i], nil
|
||||
}
|
||||
|
||||
func (p *PeerRoutes) Size() int {
|
||||
return len(p.routes)
|
||||
}
|
||||
10
client/android/platform_files.go
Normal file
10
client/android/platform_files.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
// PlatformFiles groups paths to files used internally by the engine that can't be created/modified
|
||||
// at their default locations due to android OS restrictions.
|
||||
type PlatformFiles interface {
|
||||
ConfigurationFilePath() string
|
||||
StateFilePath() string
|
||||
}
|
||||
67
client/android/route_command.go
Normal file
67
client/android/route_command.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func executeRouteToggle(id string, manager routemanager.Manager,
|
||||
operationName string,
|
||||
routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error {
|
||||
netID := route.NetID(id)
|
||||
routes := []route.NetID{netID}
|
||||
|
||||
log.Debugf("%s with id: %s", operationName, id)
|
||||
|
||||
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
||||
log.Debugf("error when %s: %s", operationName, err)
|
||||
return fmt.Errorf("error %s: %w", operationName, err)
|
||||
}
|
||||
|
||||
manager.TriggerSelection(manager.GetClientRoutes())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type routeCommand interface {
|
||||
toggleRoute() error
|
||||
}
|
||||
|
||||
type selectRouteCommand struct {
|
||||
route string
|
||||
manager routemanager.Manager
|
||||
}
|
||||
|
||||
func (s selectRouteCommand) toggleRoute() error {
|
||||
routeSelector := s.manager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
return fmt.Errorf("no route selector available")
|
||||
}
|
||||
|
||||
routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error {
|
||||
return routeSelector.SelectRoutes(routes, true, allRoutes)
|
||||
}
|
||||
|
||||
return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation)
|
||||
}
|
||||
|
||||
type deselectRouteCommand struct {
|
||||
route string
|
||||
manager routemanager.Manager
|
||||
}
|
||||
|
||||
func (d deselectRouteCommand) toggleRoute() error {
|
||||
routeSelector := d.manager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
return fmt.Errorf("no route selector available")
|
||||
}
|
||||
|
||||
return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes)
|
||||
}
|
||||
@@ -4,14 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
@@ -332,7 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
hint = profileState.Email
|
||||
}
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
if err := openBrowser(verificationURIComplete); err != nil {
|
||||
if err := util.OpenBrowser(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
||||
func openBrowser(url string) error {
|
||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||
return exec.Command(browser, url).Start()
|
||||
}
|
||||
return open.Run(url)
|
||||
}
|
||||
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
|
||||
@@ -51,6 +51,7 @@ var (
|
||||
identityFile string
|
||||
skipCachedToken bool
|
||||
requestPTY bool
|
||||
sshNoBrowser bool
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -81,6 +82,7 @@ func init() {
|
||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
||||
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
|
||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||
@@ -185,6 +187,21 @@ func getEnvOrDefault(flagName, defaultValue string) string {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
|
||||
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
|
||||
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// resetSSHGlobals sets SSH globals to their default values
|
||||
func resetSSHGlobals() {
|
||||
port = sshserver.DefaultSSHPort
|
||||
@@ -196,6 +213,7 @@ func resetSSHGlobals() {
|
||||
strictHostKeyChecking = true
|
||||
knownHostsFile = ""
|
||||
identityFile = ""
|
||||
sshNoBrowser = false
|
||||
}
|
||||
|
||||
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||
@@ -370,6 +388,7 @@ type sshFlags struct {
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
SkipCachedToken bool
|
||||
NoBrowser bool
|
||||
ConfigPath string
|
||||
LogLevel string
|
||||
LocalForwards []string
|
||||
@@ -381,6 +400,7 @@ type sshFlags struct {
|
||||
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil)
|
||||
@@ -401,6 +421,7 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
|
||||
|
||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
||||
@@ -449,6 +470,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
knownHostsFile = flags.KnownHostsFile
|
||||
identityFile = flags.IdentityFile
|
||||
skipCachedToken = flags.SkipCachedToken
|
||||
sshNoBrowser = flags.NoBrowser
|
||||
|
||||
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = flags.ConfigPath
|
||||
@@ -508,6 +530,7 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
DaemonAddr: daemonAddr,
|
||||
SkipCachedToken: skipCachedToken,
|
||||
InsecureSkipVerify: !strictHostKeyChecking,
|
||||
NoBrowser: sshNoBrowser,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -763,7 +786,15 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
||||
// Check env var for browser setting since this command is invoked via SSH ProxyCommand
|
||||
// where command-line flags cannot be passed. Default is to open browser.
|
||||
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||
var browserOpener func(string) error
|
||||
if !noBrowser {
|
||||
browserOpener = util.OpenBrowser
|
||||
}
|
||||
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SSH proxy: %w", err)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
@@ -24,8 +26,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -116,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
tableNat = "nat"
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
tableRaw = "raw"
|
||||
tableSecurity = "security"
|
||||
|
||||
chainNameNatPrerouting = "PREROUTING"
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
@@ -91,11 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, fmt.Errorf("load filter table: %w", err)
|
||||
}
|
||||
log.Debugf("ip filter table not found: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
@@ -175,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to list tables: %v", err)
|
||||
return nil, fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
@@ -187,14 +187,39 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func hookName(hook *nftables.ChainHook) string {
|
||||
if hook == nil {
|
||||
return "unknown"
|
||||
}
|
||||
switch *hook {
|
||||
case *nftables.ChainHookForward:
|
||||
return chainNameForward
|
||||
case *nftables.ChainHookInput:
|
||||
return chainNameInput
|
||||
default:
|
||||
return fmt.Sprintf("hook(%d)", *hook)
|
||||
}
|
||||
}
|
||||
|
||||
func familyName(family nftables.TableFamily) string {
|
||||
switch family {
|
||||
case nftables.TableFamilyIPv4:
|
||||
return "ip"
|
||||
case nftables.TableFamilyIPv6:
|
||||
return "ip6"
|
||||
case nftables.TableFamilyINet:
|
||||
return "inet"
|
||||
default:
|
||||
return fmt.Sprintf("family(%d)", family)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
@@ -236,9 +261,12 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add single nat rule: %v", err)
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
r.addPostroutingRules()
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
@@ -250,11 +278,7 @@ func (r *router) createContainers() error {
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -695,7 +719,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
// addPostroutingRules adds the masquerade rules
|
||||
func (r *router) addPostroutingRules() error {
|
||||
func (r *router) addPostroutingRules() {
|
||||
// First masquerade rule for traffic coming in from WireGuard interface
|
||||
exprs := []expr.Any{
|
||||
// Match on the first fwmark
|
||||
@@ -761,8 +785,6 @@ func (r *router) addPostroutingRules() error {
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs2,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
@@ -839,7 +861,7 @@ func (r *router) addMSSClampingRules() error {
|
||||
Exprs: exprsOut,
|
||||
})
|
||||
|
||||
return nil
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
@@ -939,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||
func (r *router) acceptForwardRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.acceptFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -953,11 +988,11 @@ func (r *router) acceptForwardRules() error {
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
// filter table exists but iptables is not
|
||||
// iptables is not available but the filter table exists
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables()
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
|
||||
return r.acceptFilterRulesIptables(ipt)
|
||||
@@ -968,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables forward rule: %v", rule)
|
||||
}
|
||||
@@ -976,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables input rule: %v", inputRule)
|
||||
}
|
||||
@@ -996,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
|
||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterRulesNftables() error {
|
||||
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||
// This is used when iptables is not available.
|
||||
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
forwardChain := &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertForwardAcceptRules(forwardChain, intf)
|
||||
|
||||
inputChain := &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertInputAcceptRule(inputChain, intf)
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||
func (r *router) acceptExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Hooknum == nil {
|
||||
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||
|
||||
switch *chain.Hooknum {
|
||||
case *nftables.ChainHookForward:
|
||||
r.insertForwardAcceptRules(chain, intf)
|
||||
case *nftables.ChainHookInput:
|
||||
r.insertInputAcceptRule(chain, intf)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush external chain rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||
iifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -1030,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
|
||||
Data: intf,
|
||||
},
|
||||
}
|
||||
|
||||
oifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
r.conn.InsertRule(oifRule)
|
||||
}
|
||||
|
||||
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||
inputRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -1067,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
|
||||
UserData: []byte(userDataAcceptInputRule),
|
||||
}
|
||||
r.conn.InsertRule(inputRule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptFilterRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) removeFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
return r.removeAcceptFilterRulesNftables()
|
||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
|
||||
return r.removeAcceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
if chain.Table.Name != table.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1100,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||
rules, err := r.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||
// ensuring cleanup works even after a crash or if chains changed.
|
||||
func (r *router) removeExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||
func (r *router) findExternalChains() []*nftables.Chain {
|
||||
var chains []*nftables.Chain
|
||||
|
||||
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||
|
||||
for _, family := range families {
|
||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules: %v", err)
|
||||
log.Debugf("list chains for family %d: %v", family, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
for _, chain := range allChains {
|
||||
if r.isExternalChain(chain) {
|
||||
chains = append(chains, chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
return chains
|
||||
}
|
||||
|
||||
func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||
return false
|
||||
}
|
||||
|
||||
return nil
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Type != nftables.ChainTypeFilter {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Hooknum == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||
}
|
||||
|
||||
func isIptablesTable(name string) bool {
|
||||
switch name {
|
||||
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
@@ -1128,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
@@ -1196,7 +1358,7 @@ func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(" unable to list rules: %v", err)
|
||||
return fmt.Errorf("list rules: %w", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -19,11 +20,12 @@ import (
|
||||
|
||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||
type WGTunDevice struct {
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu uint16
|
||||
iceBind *bind.ICEBind
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu uint16
|
||||
iceBind *bind.ICEBind
|
||||
// todo: review if we can eliminate the TunAdapter
|
||||
tunAdapter TunAdapter
|
||||
disableDNS bool
|
||||
|
||||
@@ -32,17 +34,19 @@ type WGTunDevice struct {
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
renewableTun *RenewableTUN
|
||||
}
|
||||
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
disableDNS: disableDNS,
|
||||
address: address,
|
||||
port: port,
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
disableDNS: disableDNS,
|
||||
renewableTun: NewRenewableTUN(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||
unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||
|
||||
t.name = name
|
||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||
t.filteredDevice = newDeviceFilter(t.renewableTun)
|
||||
|
||||
log.Debugf("attaching to interface %v", name)
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||
@@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) RenewTun(fd int) error {
|
||||
if t.device == nil {
|
||||
return fmt.Errorf("device not initialized")
|
||||
}
|
||||
|
||||
unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
log.Errorf("failed to renew Android interface: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||
// todo implement
|
||||
return nil
|
||||
|
||||
@@ -2,6 +2,13 @@
|
||||
|
||||
package device
|
||||
|
||||
import "fmt"
|
||||
|
||||
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
||||
return t.create()
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) RenewTun(fd int) error {
|
||||
// Doesn't make sense in Android for Netstack.
|
||||
return fmt.Errorf("this function has not been implemented in Netstack for Android")
|
||||
}
|
||||
|
||||
309
client/iface/device/renewable_tun.go
Normal file
309
client/iface/device/renewable_tun.go
Normal file
@@ -0,0 +1,309 @@
|
||||
//go:build android
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
// closeAwareDevice wraps a tun.Device along with a flag
|
||||
// indicating whether its Close method was called.
|
||||
//
|
||||
// It also redirects tun.Device's Events() to a separate goroutine
|
||||
// and closes it when Close is called.
|
||||
//
|
||||
// The WaitGroup and CloseOnce fields are used to ensure that the
|
||||
// goroutine is awaited and closed only once.
|
||||
type closeAwareDevice struct {
|
||||
isClosed atomic.Bool
|
||||
tun.Device
|
||||
closeEventCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func newClosableDevice(tunDevice tun.Device) *closeAwareDevice {
|
||||
return &closeAwareDevice{
|
||||
Device: tunDevice,
|
||||
isClosed: atomic.Bool{},
|
||||
closeEventCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// redirectEvents redirects the Events() method of the underlying tun.Device
|
||||
// to the given channel (RenewableTUN's events channel).
|
||||
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-c.Device.Events():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if ev == tun.EventDown {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case out <- ev:
|
||||
case <-c.closeEventCh:
|
||||
return
|
||||
}
|
||||
case <-c.closeEventCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Close calls the underlying Device's Close method
|
||||
// after setting isClosed to true.
|
||||
func (c *closeAwareDevice) Close() (err error) {
|
||||
c.closeOnce.Do(func() {
|
||||
c.isClosed.Store(true)
|
||||
close(c.closeEventCh)
|
||||
err = c.Device.Close()
|
||||
c.wg.Wait()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *closeAwareDevice) IsClosed() bool {
|
||||
return c.isClosed.Load()
|
||||
}
|
||||
|
||||
type RenewableTUN struct {
|
||||
devices []*closeAwareDevice
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
events chan tun.Event
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func NewRenewableTUN() *RenewableTUN {
|
||||
r := &RenewableTUN{
|
||||
devices: make([]*closeAwareDevice, 0),
|
||||
mu: sync.Mutex{},
|
||||
events: make(chan tun.Event, 16),
|
||||
}
|
||||
r.cond = sync.NewCond(&r.mu)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) File() *os.File {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
file := dev.File()
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return file
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads from an underlying tun.Device kept in the r.devices slice.
|
||||
// If no device is available, it waits for one to be added via AddDevice().
|
||||
//
|
||||
// On error, it retries reading from the newest device instead of returning the error
|
||||
// if the device is closed; if not, it propagates the error.
|
||||
func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
// wait until AddDevice() signals a new device via cond.Broadcast()
|
||||
if !r.waitForDevice() { // returns false if the renewable TUN itself is closed
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
n, err = dev.Read(bufs, sizes, offset)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// swap in progress; retry on the newest instead of returning the error
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return n, err // propagate non-swap error
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes to underlying tun.Device kept in the r.devices slice.
|
||||
// If no device is available, it waits for one to be added via AddDevice().
|
||||
//
|
||||
// On error, it retries writing to the newest device instead of returning the error
|
||||
// if the device is closed; if not, it propagates the error.
|
||||
func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := dev.Write(bufs, offset)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) MTU() (int, error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
mtu, err := dev.MTU()
|
||||
if err == nil {
|
||||
return mtu, nil
|
||||
}
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) Name() (string, error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return "", io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
name, err := dev.Name()
|
||||
if err == nil {
|
||||
return name, nil
|
||||
}
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Events returns a channel that is fed events from the underlying tun.Device's events channel
|
||||
// once it is added.
|
||||
func (r *RenewableTUN) Events() <-chan tun.Event {
|
||||
return r.events
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) Close() error {
|
||||
// Attempts to set the RenewableTUN closed flag to true.
|
||||
// If it's already true, returns immediately.
|
||||
if !r.closed.CompareAndSwap(false, true) {
|
||||
return nil // already closed: idempotent
|
||||
}
|
||||
r.mu.Lock()
|
||||
devices := r.devices
|
||||
r.devices = nil
|
||||
r.cond.Broadcast()
|
||||
r.mu.Unlock()
|
||||
|
||||
var lastErr error
|
||||
|
||||
log.Debugf("closing %d devices", len(devices))
|
||||
for _, device := range devices {
|
||||
if err := device.Close(); err != nil {
|
||||
log.Debugf("error closing a device: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
close(r.events)
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) AddDevice(device tun.Device) {
|
||||
r.mu.Lock()
|
||||
if r.closed.Load() {
|
||||
r.mu.Unlock()
|
||||
_ = device.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var toClose *closeAwareDevice
|
||||
if len(r.devices) > 0 {
|
||||
toClose = r.devices[len(r.devices)-1]
|
||||
}
|
||||
|
||||
cad := newClosableDevice(device)
|
||||
cad.redirectEvents(r.events)
|
||||
|
||||
r.devices = []*closeAwareDevice{cad}
|
||||
r.cond.Broadcast()
|
||||
|
||||
r.mu.Unlock()
|
||||
|
||||
if toClose != nil {
|
||||
if err := toClose.Close(); err != nil {
|
||||
log.Debugf("error closing last device: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) waitForDevice() bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for len(r.devices) == 0 && !r.closed.Load() {
|
||||
r.cond.Wait()
|
||||
}
|
||||
return !r.closed.Load()
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) peekLast() *closeAwareDevice {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if len(r.devices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.devices[len(r.devices)-1]
|
||||
}
|
||||
@@ -21,5 +21,6 @@ type WGTunDevice interface {
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
RenewTun(fd int) error
|
||||
GetICEBind() device.EndpointManager
|
||||
}
|
||||
|
||||
@@ -24,3 +24,7 @@ func (w *WGIface) Create() error {
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on non mobile")
|
||||
}
|
||||
|
||||
func (w *WGIface) RenewTun(fd int) error {
|
||||
return fmt.Errorf("this function has not been implemented on non-android")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
// todo: review does this function really necessary or can we merge it with iOS
|
||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
@@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
|
||||
func (w *WGIface) Create() error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
||||
|
||||
func (w *WGIface) RenewTun(fd int) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.tun.RenewTun(fd)
|
||||
}
|
||||
|
||||
@@ -39,3 +39,7 @@ func (w *WGIface) Create() error {
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
||||
|
||||
func (w *WGIface) RenewTun(fd int) error {
|
||||
return fmt.Errorf("this function has not been implemented on this platform")
|
||||
}
|
||||
|
||||
@@ -60,14 +60,19 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
return t.AccessToken
|
||||
}
|
||||
|
||||
func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool {
|
||||
return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient
|
||||
}
|
||||
|
||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||
//
|
||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||
// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV)
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) {
|
||||
if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||
@@ -46,9 +48,10 @@ type PKCEAuthorizationFlow struct {
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
// find the first available redirect URL
|
||||
excludedRanges := getSystemExcludedPortRanges()
|
||||
|
||||
for _, redirectURL := range config.RedirectURLs {
|
||||
if !isRedirectURLPortUsed(redirectURL) {
|
||||
if !isRedirectURLPortUsed(redirectURL, excludedRanges) {
|
||||
availableRedirectURL = redirectURL
|
||||
break
|
||||
}
|
||||
@@ -102,10 +105,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||
}
|
||||
if !p.providerConfig.DisablePromptLogin {
|
||||
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
||||
switch p.providerConfig.LoginFlag {
|
||||
case common.LoginFlagPromptLogin:
|
||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||
}
|
||||
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
||||
case common.LoginFlagMaxAge0:
|
||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||
}
|
||||
}
|
||||
@@ -282,15 +285,22 @@ func createCodeChallenge(codeVerifier string) string {
|
||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
||||
}
|
||||
|
||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows.
|
||||
func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool {
|
||||
parsedURL, err := url.Parse(redirectURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse redirect URL: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
||||
port := parsedURL.Port()
|
||||
|
||||
if isPortInExcludedRange(port, excludedRanges) {
|
||||
log.Warnf("port %s is in Windows excluded port range, skipping", port)
|
||||
return true
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%s", port)
|
||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||
if err != nil {
|
||||
return false
|
||||
@@ -304,6 +314,33 @@ func isRedirectURLPortUsed(redirectURL string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// excludedPortRange represents a range of excluded ports.
|
||||
type excludedPortRange struct {
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
// isPortInExcludedRange checks if the given port is in any of the excluded ranges.
|
||||
func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool {
|
||||
if len(excludedRanges) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
log.Debugf("invalid port number %s: %v", port, err)
|
||||
return false
|
||||
}
|
||||
|
||||
for _, r := range excludedRanges {
|
||||
if portNum >= r.start && portNum <= r.end {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
||||
if err != nil {
|
||||
|
||||
8
client/internal/auth/pkce_flow_other.go
Normal file
8
client/internal/auth/pkce_flow_other.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
|
||||
package auth
|
||||
|
||||
// getSystemExcludedPortRanges returns nil on non-Windows platforms.
|
||||
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||
return nil
|
||||
}
|
||||
@@ -2,8 +2,11 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -20,22 +23,28 @@ func TestPromptLogin(t *testing.T) {
|
||||
name string
|
||||
loginFlag mgm.LoginFlag
|
||||
disablePromptLogin bool
|
||||
expect string
|
||||
expectContains []string
|
||||
}{
|
||||
{
|
||||
name: "Prompt login",
|
||||
loginFlag: mgm.LoginFlagPrompt,
|
||||
expect: promptLogin,
|
||||
name: "Prompt login",
|
||||
loginFlag: mgm.LoginFlagPromptLogin,
|
||||
expectContains: []string{promptLogin},
|
||||
},
|
||||
{
|
||||
name: "Max age 0 login",
|
||||
loginFlag: mgm.LoginFlagMaxAge0,
|
||||
expect: maxAge0,
|
||||
name: "Max age 0",
|
||||
loginFlag: mgm.LoginFlagMaxAge0,
|
||||
expectContains: []string{maxAge0},
|
||||
},
|
||||
{
|
||||
name: "Disable prompt login",
|
||||
loginFlag: mgm.LoginFlagPrompt,
|
||||
loginFlag: mgm.LoginFlagPromptLogin,
|
||||
disablePromptLogin: true,
|
||||
expectContains: []string{},
|
||||
},
|
||||
{
|
||||
name: "None flag should not add parameters",
|
||||
loginFlag: mgm.LoginFlagNone,
|
||||
expectContains: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -50,6 +59,7 @@ func TestPromptLogin(t *testing.T) {
|
||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||
UseIDToken: true,
|
||||
LoginFlag: tc.loginFlag,
|
||||
DisablePromptLogin: tc.disablePromptLogin,
|
||||
}
|
||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||
if err != nil {
|
||||
@@ -60,12 +70,153 @@ func TestPromptLogin(t *testing.T) {
|
||||
t.Fatalf("Failed to request auth info: %v", err)
|
||||
}
|
||||
|
||||
if !tc.disablePromptLogin {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
||||
} else {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
||||
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
||||
for _, expected := range tc.expectContains {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPortInExcludedRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
port string
|
||||
excludedRanges []excludedPortRange
|
||||
expectedBlocked bool
|
||||
}{
|
||||
{
|
||||
name: "Port in excluded range",
|
||||
port: "8080",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Port at start of range",
|
||||
port: "8000",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Port at end of range",
|
||||
port: "8100",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Port before range",
|
||||
port: "7999",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Port after range",
|
||||
port: "8101",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Empty excluded ranges",
|
||||
port: "8080",
|
||||
excludedRanges: []excludedPortRange{},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Nil excluded ranges",
|
||||
port: "8080",
|
||||
excludedRanges: nil,
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple ranges - port in second range",
|
||||
port: "9050",
|
||||
excludedRanges: []excludedPortRange{
|
||||
{start: 8000, end: 8100},
|
||||
{start: 9000, end: 9100},
|
||||
},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple ranges - port not in any range",
|
||||
port: "8500",
|
||||
excludedRanges: []excludedPortRange{
|
||||
{start: 8000, end: 8100},
|
||||
{start: 9000, end: 9100},
|
||||
},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid port string",
|
||||
port: "invalid",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Empty port string",
|
||||
port: "",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isPortInExcludedRange(tt.port, tt.excludedRanges)
|
||||
assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRedirectURLPortUsed(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
usedPort := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
redirectURL string
|
||||
excludedRanges []excludedPortRange
|
||||
expectedUsed bool
|
||||
}{
|
||||
{
|
||||
name: "Port in excluded range",
|
||||
redirectURL: "http://127.0.0.1:8080/",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedUsed: true,
|
||||
},
|
||||
{
|
||||
name: "Port actually in use",
|
||||
redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort),
|
||||
excludedRanges: nil,
|
||||
expectedUsed: true,
|
||||
},
|
||||
{
|
||||
name: "Port not in use and not excluded",
|
||||
redirectURL: "http://127.0.0.1:65432/",
|
||||
excludedRanges: nil,
|
||||
expectedUsed: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid URL without port",
|
||||
redirectURL: "not-a-valid-url",
|
||||
excludedRanges: nil,
|
||||
expectedUsed: false,
|
||||
},
|
||||
{
|
||||
name: "Port excluded even if not in use",
|
||||
redirectURL: "http://127.0.0.1:8050/",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedUsed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges)
|
||||
assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
86
client/internal/auth/pkce_flow_windows.go
Normal file
86
client/internal/auth/pkce_flow_windows.go
Normal file
@@ -0,0 +1,86 @@
|
||||
//go:build windows
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh.
|
||||
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||
ranges, err := getExcludedPortRangesFromNetsh()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get Windows excluded port ranges: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ranges
|
||||
}
|
||||
|
||||
// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command.
|
||||
func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) {
|
||||
cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("netsh command: %w", err)
|
||||
}
|
||||
|
||||
return parseExcludedPortRanges(string(output))
|
||||
}
|
||||
|
||||
// parseExcludedPortRanges parses the output of the netsh command to extract port ranges.
|
||||
func parseExcludedPortRanges(output string) ([]excludedPortRange, error) {
|
||||
var ranges []excludedPortRange
|
||||
scanner := bufio.NewScanner(strings.NewReader(output))
|
||||
|
||||
foundHeader := false
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") {
|
||||
foundHeader = true
|
||||
continue
|
||||
}
|
||||
|
||||
if !foundHeader {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(line, "----------") {
|
||||
continue
|
||||
}
|
||||
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
startPort, err := strconv.Atoi(fields[0])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
endPort, err := strconv.Atoi(fields[1])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ranges = append(ranges, excludedPortRange{start: startPort, end: endPort})
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("scan output: %w", err)
|
||||
}
|
||||
|
||||
return ranges, nil
|
||||
}
|
||||
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
//go:build windows
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestParseExcludedPortRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
netshOutput string
|
||||
expectedRanges []excludedPortRange
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid netsh output with multiple ranges",
|
||||
netshOutput: `
|
||||
Protocol tcp Dynamic Port Range
|
||||
---------------------------------
|
||||
Start Port : 49152
|
||||
Number of Ports : 16384
|
||||
|
||||
Protocol tcp Excluded Port Ranges
|
||||
---------------------------------
|
||||
Start Port End Port
|
||||
---------- --------
|
||||
5357 5357 *
|
||||
50000 50059 *
|
||||
`,
|
||||
expectedRanges: []excludedPortRange{
|
||||
{start: 5357, end: 5357},
|
||||
{start: 50000, end: 50059},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty output",
|
||||
netshOutput: `
|
||||
Protocol tcp Dynamic Port Range
|
||||
---------------------------------
|
||||
Start Port : 49152
|
||||
Number of Ports : 16384
|
||||
`,
|
||||
expectedRanges: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Single range",
|
||||
netshOutput: `
|
||||
Protocol tcp Excluded Port Ranges
|
||||
---------------------------------
|
||||
Start Port End Port
|
||||
---------- --------
|
||||
8080 8090
|
||||
`,
|
||||
expectedRanges: []excludedPortRange{
|
||||
{start: 8080, end: 8090},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ranges, err := parseExcludedPortRanges(tt.netshOutput)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedRanges, ranges)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||
ranges := getSystemExcludedPortRanges()
|
||||
t.Logf("Found %d excluded port ranges on this system", len(ranges))
|
||||
|
||||
listener1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = listener1.Close()
|
||||
}()
|
||||
usedPort1 := listener1.Addr().(*net.TCPAddr).Port
|
||||
|
||||
availablePort := 65432
|
||||
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
Scope: "openid email profile",
|
||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||
RedirectURLs: []string{
|
||||
fmt.Sprintf("http://127.0.0.1:%d/", usedPort1),
|
||||
fmt.Sprintf("http://127.0.0.1:%d/", availablePort),
|
||||
},
|
||||
UseIDToken: true,
|
||||
}
|
||||
|
||||
flow, err := NewPKCEAuthorizationFlow(config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, flow)
|
||||
assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort),
|
||||
"Should skip port in use and select available port")
|
||||
}
|
||||
@@ -74,6 +74,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
mobileDependency := MobileDependency{
|
||||
@@ -82,6 +83,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
@@ -271,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
@@ -291,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
<-engineCtx.Done()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
engine := c.engine
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if engine != nil && engine.wgInterface != nil {
|
||||
// todo: consider to remove this condition. Is not thread safe.
|
||||
// We should always call Stop(), but we need to verify that it is idempotent
|
||||
if engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
|
||||
@@ -56,6 +56,7 @@ block.prof: Block profiling information.
|
||||
heap.prof: Heap profiling information (snapshot of memory allocations).
|
||||
allocs.prof: Allocations profiling information.
|
||||
threadcreate.prof: Thread creation profiling information.
|
||||
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||
|
||||
|
||||
Anonymization Process
|
||||
@@ -109,6 +110,9 @@ go tool pprof -http=:8088 heap.prof
|
||||
|
||||
This will open a web browser tab with the profiling information.
|
||||
|
||||
Stack Trace
|
||||
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
|
||||
|
||||
Routes
|
||||
The routes.txt file contains detailed routing table information in a tabular format:
|
||||
|
||||
@@ -327,6 +331,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add profiles to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addStackTrace(); err != nil {
|
||||
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addSyncResponse(); err != nil {
|
||||
return fmt.Errorf("add sync response: %w", err)
|
||||
}
|
||||
@@ -522,6 +530,18 @@ func (g *BundleGenerator) addProf() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStackTrace() error {
|
||||
buf := make([]byte, 5242880) // 5 MB buffer
|
||||
n := runtime.Stack(buf, true)
|
||||
|
||||
stackTrace := bytes.NewReader(buf[:n])
|
||||
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
|
||||
return fmt.Errorf("add stack trace file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addInterfaces() error {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
|
||||
@@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.SkipPTRProcess {
|
||||
continue
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
if record.Type != int(dns.TypeA) {
|
||||
continue
|
||||
@@ -106,8 +109,9 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
||||
records := collectPTRRecords(config, network)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
SearchDomainDisabled: true,
|
||||
}
|
||||
|
||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||
|
||||
@@ -11,11 +11,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4ReverseZone = ".in-addr.arpa."
|
||||
ipv6ReverseZone = ".ip6.arpa."
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||
restoreHostDNS() error
|
||||
@@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H
|
||||
}
|
||||
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||
MatchOnly: matchOnly,
|
||||
MatchOnly: customZone.SearchDomainDisabled,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
timeoutMsg += " " + peerInfo
|
||||
}
|
||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||
logger.Warnf(timeoutMsg)
|
||||
logger.Warn(timeoutMsg)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
|
||||
@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||
for i, ip := range ips {
|
||||
ips[i] = ip.Unmap()
|
||||
}
|
||||
|
||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
f.cache.set(domain, question.Qtype, ips)
|
||||
|
||||
@@ -255,7 +255,7 @@ func NewEngine(
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
|
||||
path := sm.GetStatePath()
|
||||
if runtime.GOOS == "ios" {
|
||||
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
|
||||
if !fileExists(mobileDep.StateFilePath) {
|
||||
err := createFile(mobileDep.StateFilePath)
|
||||
if err != nil {
|
||||
@@ -280,7 +280,6 @@ func (e *Engine) Stop() error {
|
||||
return nil
|
||||
}
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.Close()
|
||||
@@ -298,9 +297,6 @@ func (e *Engine) Stop() error {
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.ingressGatewayMgr != nil {
|
||||
if err := e.ingressGatewayMgr.Close(); err != nil {
|
||||
log.Warnf("failed to cleanup forward rules: %v", err)
|
||||
@@ -308,24 +304,29 @@ func (e *Engine) Stop() error {
|
||||
e.ingressGatewayMgr = nil
|
||||
}
|
||||
|
||||
e.stopDNSForwarder()
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop(e.stateManager)
|
||||
}
|
||||
|
||||
if e.srWatcher != nil {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
|
||||
log.Info("cleaning up status recorder states")
|
||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||
|
||||
if err := e.removeAllPeers(); err != nil {
|
||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||
log.Errorf("failed to remove all peers: %s", err)
|
||||
}
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop(e.stateManager)
|
||||
}
|
||||
|
||||
e.stopDNSForwarder()
|
||||
|
||||
// stop/restore DNS after peers are closed but before interface goes down
|
||||
// so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
@@ -337,16 +338,18 @@ func (e *Engine) Stop() error {
|
||||
e.flowManager.Close()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer stateCancel()
|
||||
|
||||
if err := e.stateManager.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||
if err := e.stateManager.Stop(stateCtx); err != nil {
|
||||
log.Errorf("failed to stop state manager: %v", err)
|
||||
}
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
@@ -432,8 +435,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create rosenpass manager: %w", err)
|
||||
}
|
||||
err := e.rpManager.Run()
|
||||
if err != nil {
|
||||
if err := e.rpManager.Run(); err != nil {
|
||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -485,6 +487,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
}
|
||||
|
||||
if err := e.createFirewall(); err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -750,6 +753,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
wCfg := update.GetNetbirdConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
@@ -1207,7 +1215,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
|
||||
|
||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||
dnsZone := nbdns.CustomZone{
|
||||
Domain: zone.GetDomain(),
|
||||
Domain: zone.GetDomain(),
|
||||
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
|
||||
SkipPTRProcess: zone.GetSkipPTRProcess(),
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
dnsRecord := nbdns.SimpleRecord{
|
||||
@@ -1367,6 +1377,11 @@ func (e *Engine) receiveSignalEvents() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
@@ -1831,6 +1846,18 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
||||
return e.wgInterface.Address().IP
|
||||
}
|
||||
|
||||
func (e *Engine) RenewTun(fd int) error {
|
||||
e.syncMsgMux.Lock()
|
||||
wgInterface := e.wgInterface
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
if wgInterface == nil {
|
||||
return fmt.Errorf("wireguard interface not initialized")
|
||||
}
|
||||
|
||||
return wgInterface.RenewTun(fd)
|
||||
}
|
||||
|
||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||
func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
|
||||
@@ -30,11 +30,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -54,7 +55,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -110,6 +110,10 @@ type MockWGIface struct {
|
||||
LastActivitiesFunc func() map[string]monotime.Time
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RenewTun(_ int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1624,14 +1628,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
type wgIfaceBase interface {
|
||||
Create() error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
RenewTun(fd int) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
|
||||
@@ -20,7 +20,7 @@ type EndpointUpdater struct {
|
||||
wgConfig WgConfig
|
||||
initiator bool
|
||||
|
||||
// mu protects updateWireGuardPeer and cancelFunc
|
||||
// mu protects cancelFunc
|
||||
mu sync.Mutex
|
||||
cancelFunc func()
|
||||
updateWg sync.WaitGroup
|
||||
@@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
e.mu.Lock()
|
||||
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
|
||||
}
|
||||
e.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
218
client/internal/sleep/detector_darwin.go
Normal file
218
client/internal/sleep/detector_darwin.go
Normal file
@@ -0,0 +1,218 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package sleep
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
|
||||
#include <IOKit/pwr_mgt/IOPMLib.h>
|
||||
#include <IOKit/IOMessage.h>
|
||||
#include <CoreFoundation/CoreFoundation.h>
|
||||
|
||||
extern void sleepCallbackBridge();
|
||||
extern void poweredOnCallbackBridge();
|
||||
extern void suspendedCallbackBridge();
|
||||
extern void resumedCallbackBridge();
|
||||
|
||||
|
||||
// C global variables for IOKit state
|
||||
static IONotificationPortRef g_notifyPortRef = NULL;
|
||||
static io_object_t g_notifierObject = 0;
|
||||
static io_object_t g_generalInterestNotifier = 0;
|
||||
static io_connect_t g_rootPort = 0;
|
||||
static CFRunLoopRef g_runLoop = NULL;
|
||||
|
||||
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
|
||||
switch (messageType) {
|
||||
case kIOMessageSystemWillSleep:
|
||||
sleepCallbackBridge();
|
||||
IOAllowPowerChange(g_rootPort, (long)messageArgument);
|
||||
break;
|
||||
case kIOMessageSystemHasPoweredOn:
|
||||
poweredOnCallbackBridge();
|
||||
break;
|
||||
case kIOMessageServiceIsSuspended:
|
||||
suspendedCallbackBridge();
|
||||
break;
|
||||
case kIOMessageServiceIsResumed:
|
||||
resumedCallbackBridge();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void registerNotifications() {
|
||||
g_rootPort = IORegisterForSystemPower(
|
||||
NULL,
|
||||
&g_notifyPortRef,
|
||||
(IOServiceInterestCallback)sleepCallback,
|
||||
&g_notifierObject
|
||||
);
|
||||
|
||||
if (g_rootPort == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
CFRunLoopAddSource(CFRunLoopGetCurrent(),
|
||||
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||
kCFRunLoopCommonModes);
|
||||
|
||||
g_runLoop = CFRunLoopGetCurrent();
|
||||
CFRunLoopRun();
|
||||
}
|
||||
|
||||
static void unregisterNotifications() {
|
||||
CFRunLoopRemoveSource(g_runLoop,
|
||||
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||
kCFRunLoopCommonModes);
|
||||
|
||||
IODeregisterForSystemPower(&g_notifierObject);
|
||||
IOServiceClose(g_rootPort);
|
||||
IONotificationPortDestroy(g_notifyPortRef);
|
||||
CFRunLoopStop(g_runLoop);
|
||||
|
||||
g_notifyPortRef = NULL;
|
||||
g_notifierObject = 0;
|
||||
g_rootPort = 0;
|
||||
g_runLoop = NULL;
|
||||
}
|
||||
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
serviceRegistry = make(map[*Detector]struct{})
|
||||
serviceRegistryMu sync.Mutex
|
||||
)
|
||||
|
||||
//export sleepCallbackBridge
|
||||
func sleepCallbackBridge() {
|
||||
log.Info("sleepCallbackBridge event triggered")
|
||||
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
|
||||
for svc := range serviceRegistry {
|
||||
svc.triggerCallback(EventTypeSleep)
|
||||
}
|
||||
}
|
||||
|
||||
//export resumedCallbackBridge
|
||||
func resumedCallbackBridge() {
|
||||
log.Info("resumedCallbackBridge event triggered")
|
||||
}
|
||||
|
||||
//export suspendedCallbackBridge
|
||||
func suspendedCallbackBridge() {
|
||||
log.Info("suspendedCallbackBridge event triggered")
|
||||
}
|
||||
|
||||
//export poweredOnCallbackBridge
|
||||
func poweredOnCallbackBridge() {
|
||||
log.Info("poweredOnCallbackBridge event triggered")
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
|
||||
for svc := range serviceRegistry {
|
||||
svc.triggerCallback(EventTypeWakeUp)
|
||||
}
|
||||
}
|
||||
|
||||
type Detector struct {
|
||||
callback func(event EventType)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewDetector() (*Detector, error) {
|
||||
return &Detector{}, nil
|
||||
}
|
||||
|
||||
func (d *Detector) Register(callback func(event EventType)) error {
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
|
||||
if _, exists := serviceRegistry[d]; exists {
|
||||
return fmt.Errorf("detector service already registered")
|
||||
}
|
||||
|
||||
d.callback = callback
|
||||
|
||||
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||
|
||||
if len(serviceRegistry) > 0 {
|
||||
serviceRegistry[d] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
serviceRegistry[d] = struct{}{}
|
||||
|
||||
// CFRunLoop must run on a single fixed OS thread
|
||||
go func() {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
C.registerNotifications()
|
||||
}()
|
||||
|
||||
log.Info("sleep detection service started on macOS")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
|
||||
// and the runloop is stopped and cleaned up.
|
||||
func (d *Detector) Deregister() error {
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
_, exists := serviceRegistry[d]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// cancel and remove this detector
|
||||
d.cancel()
|
||||
delete(serviceRegistry, d)
|
||||
|
||||
// If other Detectors still exist, leave IOKit running
|
||||
if len(serviceRegistry) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("sleep detection service stopping (deregister)")
|
||||
|
||||
// Deregister IOKit notifications, stop runloop, and free resources
|
||||
C.unregisterNotifications()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Detector) triggerCallback(event EventType) {
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
timeout := time.NewTimer(500 * time.Millisecond)
|
||||
defer timeout.Stop()
|
||||
|
||||
cb := d.callback
|
||||
go func(callback func(event EventType)) {
|
||||
log.Info("sleep detection event fired")
|
||||
callback(event)
|
||||
close(doneChan)
|
||||
}(cb)
|
||||
|
||||
select {
|
||||
case <-doneChan:
|
||||
case <-d.ctx.Done():
|
||||
case <-timeout.C:
|
||||
log.Warnf("sleep callback timed out")
|
||||
}
|
||||
}
|
||||
9
client/internal/sleep/detector_notsupported.go
Normal file
9
client/internal/sleep/detector_notsupported.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !darwin || ios
|
||||
|
||||
package sleep
|
||||
|
||||
import "fmt"
|
||||
|
||||
func NewDetector() (detector, error) {
|
||||
return nil, fmt.Errorf("sleep not supported on this platform")
|
||||
}
|
||||
37
client/internal/sleep/service.go
Normal file
37
client/internal/sleep/service.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package sleep
|
||||
|
||||
var (
|
||||
EventTypeUnknown EventType = 0
|
||||
EventTypeSleep EventType = 1
|
||||
EventTypeWakeUp EventType = 2
|
||||
)
|
||||
|
||||
type EventType int
|
||||
|
||||
type detector interface {
|
||||
Register(callback func(eventType EventType)) error
|
||||
Deregister() error
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
detector detector
|
||||
}
|
||||
|
||||
func New() (*Service, error) {
|
||||
d, err := NewDetector()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Service{
|
||||
detector: d,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Register(callback func(eventType EventType)) error {
|
||||
return s.detector.Register(callback)
|
||||
}
|
||||
|
||||
func (s *Service) Deregister() error {
|
||||
return s.detector.Deregister()
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(fd int32, interfaceName string) error {
|
||||
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
log.Infof("Starting NetBird client")
|
||||
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
@@ -228,7 +232,7 @@ func (c *Client) LoginForMobile() string {
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
|
||||
}
|
||||
return netIDs
|
||||
}
|
||||
|
||||
func exportEnvList(list *EnvList) {
|
||||
if list == nil {
|
||||
return
|
||||
}
|
||||
for k, v := range list.AllItems() {
|
||||
log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
|
||||
log.Debugf("Setting env variable %s: %s", k, v)
|
||||
|
||||
if err := os.Setenv(k, v); err != nil {
|
||||
log.Errorf("could not set env variable %s: %v", k, err)
|
||||
} else {
|
||||
log.Debugf("Env variable %s was set successfully", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
34
client/ios/NetBirdSDK/env_list.go
Normal file
34
client/ios/NetBirdSDK/env_list.go
Normal file
@@ -0,0 +1,34 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
// EnvList is an exported struct to be bound by gomobile
|
||||
type EnvList struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
// NewEnvList creates a new EnvList
|
||||
func NewEnvList() *EnvList {
|
||||
return &EnvList{data: make(map[string]string)}
|
||||
}
|
||||
|
||||
// Put adds a key-value pair
|
||||
func (el *EnvList) Put(key, value string) {
|
||||
el.data[key] = value
|
||||
}
|
||||
|
||||
// Get retrieves a value by key
|
||||
func (el *EnvList) Get(key string) string {
|
||||
return el.data[key]
|
||||
}
|
||||
|
||||
func (el *EnvList) AllItems() map[string]string {
|
||||
return el.data
|
||||
}
|
||||
|
||||
// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
|
||||
func GetEnvKeyNBForceRelay() string {
|
||||
return peer.EnvKeyNBForceRelay
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import _ "golang.org/x/mobile/bind"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,7 @@ service DaemonService {
|
||||
// Status of the service.
|
||||
rpc Status(StatusRequest) returns (StatusResponse) {}
|
||||
|
||||
// Down engine work in the daemon.
|
||||
// Down stops engine work in the daemon.
|
||||
rpc Down(DownRequest) returns (DownResponse) {}
|
||||
|
||||
// GetConfig of the daemon.
|
||||
@@ -93,9 +93,26 @@ service DaemonService {
|
||||
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||
|
||||
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||
}
|
||||
|
||||
|
||||
|
||||
message OSLifecycleRequest {
|
||||
// avoid collision with loglevel enum
|
||||
enum CycleType {
|
||||
UNKNOWN = 0;
|
||||
SLEEP = 1;
|
||||
WAKEUP = 2;
|
||||
}
|
||||
|
||||
CycleType type = 1;
|
||||
}
|
||||
|
||||
message OSLifecycleResponse {}
|
||||
|
||||
|
||||
message LoginRequest {
|
||||
// setupKey netbird setup key.
|
||||
string setupKey = 1;
|
||||
|
||||
@@ -27,7 +27,7 @@ type DaemonServiceClient interface {
|
||||
Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error)
|
||||
// Status of the service.
|
||||
Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error)
|
||||
// Down engine work in the daemon.
|
||||
// Down stops engine work in the daemon.
|
||||
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
|
||||
// GetConfig of the daemon.
|
||||
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
|
||||
@@ -70,6 +70,7 @@ type DaemonServiceClient interface {
|
||||
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
||||
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -382,6 +383,15 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
|
||||
out := new(OSLifecycleResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility
|
||||
@@ -395,7 +405,7 @@ type DaemonServiceServer interface {
|
||||
Up(context.Context, *UpRequest) (*UpResponse, error)
|
||||
// Status of the service.
|
||||
Status(context.Context, *StatusRequest) (*StatusResponse, error)
|
||||
// Down engine work in the daemon.
|
||||
// Down stops engine work in the daemon.
|
||||
Down(context.Context, *DownRequest) (*DownResponse, error)
|
||||
// GetConfig of the daemon.
|
||||
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
|
||||
@@ -438,6 +448,7 @@ type DaemonServiceServer interface {
|
||||
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
||||
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -538,6 +549,9 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request
|
||||
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
|
||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -1112,6 +1126,24 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(OSLifecycleRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/NotifyOSLifecycle",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -1239,6 +1271,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "WaitJWTToken",
|
||||
Handler: _DaemonService_WaitJWTToken_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "NotifyOSLifecycle",
|
||||
Handler: _DaemonService_NotifyOSLifecycle_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
77
client/server/lifecycle.go
Normal file
77
client/server/lifecycle.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||
switch req.GetType() {
|
||||
case proto.OSLifecycleRequest_WAKEUP:
|
||||
return s.handleWakeUp(callerCtx)
|
||||
case proto.OSLifecycleRequest_SLEEP:
|
||||
return s.handleSleep(callerCtx)
|
||||
default:
|
||||
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||
}
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
|
||||
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
|
||||
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
|
||||
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
||||
if !s.sleepTriggeredDown.Load() {
|
||||
log.Info("skipping up because wasn't sleep down")
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
|
||||
// avoid other wakeup runs if sleep didn't make the computer sleep
|
||||
s.sleepTriggeredDown.Store(false)
|
||||
|
||||
log.Info("running up after wake up")
|
||||
_, err := s.Up(callerCtx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("running up failed: %v", err)
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
|
||||
log.Info("running up command executed successfully")
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
|
||||
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
|
||||
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
||||
s.mutex.Lock()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
|
||||
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
||||
log.Infof("skipping setting the agent down because status is %s", status)
|
||||
s.mutex.Unlock()
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
log.Info("running down after system started sleeping")
|
||||
|
||||
_, err = s.Down(callerCtx, &proto.DownRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("running down failed: %v", err)
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
log.Info("running down executed successfully")
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
219
client/server/lifecycle_test.go
Normal file
219
client/server/lifecycle_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
func newTestServer() *Server {
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
return &Server{
|
||||
rootCtx: ctx,
|
||||
statusRecorder: peer.NewRecorder(""),
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// sleepTriggeredDown is false by default
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusNeedsLogin)
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.actCancel = cancel
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
||||
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusConnected)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.actCancel = cancel
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
||||
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// Manually set the flag to simulate prior sleep down
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
|
||||
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// First wakeup without prior sleep - should be no-op
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
|
||||
// Simulate prior sleep
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
// First wakeup after sleep - should reset flag
|
||||
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
|
||||
// Second wakeup - should be no-op
|
||||
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
resp, err := s.handleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
||||
s := newTestServer()
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
// Even if Up fails, flag should be reset
|
||||
_, _ = s.handleWakeUp(context.Background())
|
||||
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
|
||||
}
|
||||
|
||||
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Idle", internal.StatusIdle},
|
||||
{"NeedsLogin", internal.StatusNeedsLogin},
|
||||
{"LoginFailed", internal.StatusLoginFailed},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := newTestServer()
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(tt.status)
|
||||
|
||||
resp, err := s.handleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Connecting", internal.StatusConnecting},
|
||||
{"Connected", internal.StatusConnected},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := newTestServer()
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(tt.status)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.actCancel = cancel
|
||||
|
||||
resp, err := s.handleSleep(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.True(t, s.sleepTriggeredDown.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -85,6 +85,9 @@ type Server struct {
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
|
||||
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
|
||||
sleepTriggeredDown atomic.Bool
|
||||
|
||||
jwtCache *jwtCache
|
||||
}
|
||||
|
||||
@@ -504,7 +507,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
if msg.Hint != nil {
|
||||
hint = *msg.Hint
|
||||
}
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, false, hint)
|
||||
if err != nil {
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
return nil, err
|
||||
@@ -819,6 +822,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if err := s.cleanupConnection(); err != nil {
|
||||
// todo review to update the status in case any type of error
|
||||
log.Errorf("failed to shut down properly: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
@@ -911,6 +915,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
|
||||
}
|
||||
|
||||
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
|
||||
// todo review to update the status in case any type of error
|
||||
log.Errorf("failed to cleanup connection: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
@@ -1235,7 +1240,7 @@ func (s *Server) RequestJWTAuth(
|
||||
}
|
||||
|
||||
isDesktop := isUnixRunningDesktop()
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, false, hint)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
@@ -17,11 +17,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -35,7 +36,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -316,14 +316,17 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -278,6 +279,7 @@ type DialOptions struct {
|
||||
DaemonAddr string
|
||||
SkipCachedToken bool
|
||||
InsecureSkipVerify bool
|
||||
NoBrowser bool
|
||||
}
|
||||
|
||||
// Dial connects to the given ssh server with specified options
|
||||
@@ -307,7 +309,7 @@ func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, er
|
||||
config.Auth = append(config.Auth, authMethod)
|
||||
}
|
||||
|
||||
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
|
||||
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken, opts.NoBrowser)
|
||||
}
|
||||
|
||||
// dialSSH establishes an SSH connection without JWT authentication
|
||||
@@ -333,7 +335,7 @@ func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig
|
||||
}
|
||||
|
||||
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
|
||||
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
|
||||
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache, noBrowser bool) (*Client, error) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse address %s: %w", addr, err)
|
||||
@@ -359,7 +361,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
|
||||
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
|
||||
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache, noBrowser)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request JWT token: %w", err)
|
||||
}
|
||||
@@ -369,7 +371,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
|
||||
}
|
||||
|
||||
// requestJWTToken requests a JWT token from the NetBird daemon
|
||||
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
|
||||
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache, noBrowser bool) (string, error) {
|
||||
hint := profilemanager.GetLoginHint()
|
||||
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
@@ -379,7 +381,13 @@ func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (st
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint)
|
||||
|
||||
var browserOpener func(string) error
|
||||
if !noBrowser {
|
||||
browserOpener = util.OpenBrowser
|
||||
}
|
||||
|
||||
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint, browserOpener)
|
||||
}
|
||||
|
||||
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
||||
|
||||
@@ -67,8 +67,31 @@ func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKe
|
||||
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
|
||||
}
|
||||
|
||||
// printAuthInstructions prints authentication instructions to stderr
|
||||
func printAuthInstructions(stderr io.Writer, authResponse *proto.RequestJWTAuthResponse, browserWillOpen bool) {
|
||||
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
||||
|
||||
if browserWillOpen {
|
||||
_, _ = fmt.Fprintln(stderr, "Please do the SSO login in your browser.")
|
||||
_, _ = fmt.Fprintln(stderr, "If your browser didn't open automatically, use this URL to log in:")
|
||||
_, _ = fmt.Fprintln(stderr)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(stderr, "%s\n", authResponse.VerificationURIComplete)
|
||||
|
||||
if authResponse.UserCode != "" {
|
||||
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
||||
}
|
||||
|
||||
if browserWillOpen {
|
||||
_, _ = fmt.Fprintln(stderr)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
||||
}
|
||||
|
||||
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
||||
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) {
|
||||
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string, openBrowser func(string) error) (string, error) {
|
||||
req := &proto.RequestJWTAuthRequest{}
|
||||
if hint != "" {
|
||||
req.Hint = &hint
|
||||
@@ -84,12 +107,13 @@ func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdo
|
||||
}
|
||||
|
||||
if stderr != nil {
|
||||
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
||||
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
|
||||
if authResponse.UserCode != "" {
|
||||
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
||||
printAuthInstructions(stderr, authResponse, openBrowser != nil)
|
||||
}
|
||||
|
||||
if openBrowser != nil {
|
||||
if err := openBrowser(authResponse.VerificationURIComplete); err != nil {
|
||||
log.Debugf("open browser: %v", err)
|
||||
}
|
||||
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
||||
}
|
||||
|
||||
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
|
||||
|
||||
@@ -35,15 +35,16 @@ const (
|
||||
)
|
||||
|
||||
type SSHProxy struct {
|
||||
daemonAddr string
|
||||
targetHost string
|
||||
targetPort int
|
||||
stderr io.Writer
|
||||
conn *grpc.ClientConn
|
||||
daemonClient proto.DaemonServiceClient
|
||||
daemonAddr string
|
||||
targetHost string
|
||||
targetPort int
|
||||
stderr io.Writer
|
||||
conn *grpc.ClientConn
|
||||
daemonClient proto.DaemonServiceClient
|
||||
browserOpener func(string) error
|
||||
}
|
||||
|
||||
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
|
||||
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
|
||||
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
@@ -51,12 +52,13 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHP
|
||||
}
|
||||
|
||||
return &SSHProxy{
|
||||
daemonAddr: daemonAddr,
|
||||
targetHost: targetHost,
|
||||
targetPort: targetPort,
|
||||
stderr: stderr,
|
||||
conn: grpcConn,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
daemonAddr: daemonAddr,
|
||||
targetHost: targetHost,
|
||||
targetPort: targetPort,
|
||||
stderr: stderr,
|
||||
conn: grpcConn,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
browserOpener: browserOpener,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -70,7 +72,7 @@ func (p *SSHProxy) Close() error {
|
||||
func (p *SSHProxy) Connect(ctx context.Context) error {
|
||||
hint := profilemanager.GetLoginHint()
|
||||
|
||||
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
|
||||
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint, p.browserOpener)
|
||||
if err != nil {
|
||||
return fmt.Errorf(jwtAuthErrorMsg, err)
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
@@ -42,6 +42,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// detectUtilLinuxLogin always returns false on JS/WASM
|
||||
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// executeCommandWithPty is not supported on JS/WASM
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
logger.Errorf("PTY command execution not supported on JS/WASM")
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -75,6 +76,29 @@ func (s *Server) detectSuPtySupport(ctx context.Context) bool {
|
||||
return supported
|
||||
}
|
||||
|
||||
// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils).
|
||||
// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent.
|
||||
// See https://bugs.debian.org/1078023 for details.
|
||||
func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "login", "--version")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Debugf("login --version failed (likely shadow-utils): %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
isUtilLinux := strings.Contains(string(output), "util-linux")
|
||||
log.Debugf("util-linux login detected: %v", isUtilLinux)
|
||||
return isUtilLinux
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching
|
||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
suPath, err := exec.LookPath("su")
|
||||
@@ -144,7 +168,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Infof("starting interactive shell: %s", execCmd.Path)
|
||||
logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " "))
|
||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||
}
|
||||
|
||||
|
||||
@@ -383,6 +383,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// detectUtilLinuxLogin always returns false on Windows
|
||||
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
command := session.RawCommand()
|
||||
|
||||
@@ -138,7 +138,8 @@ type Server struct {
|
||||
jwtExtractor *jwt.ClaimsExtractor
|
||||
jwtConfig *JWTConfig
|
||||
|
||||
suSupportsPty bool
|
||||
suSupportsPty bool
|
||||
loginIsUtilLinux bool
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
@@ -193,6 +194,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
}
|
||||
|
||||
s.suSupportsPty = s.detectSuPtySupport(ctx)
|
||||
s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx)
|
||||
|
||||
ln, addrDesc, err := s.createListener(ctx, addr)
|
||||
if err != nil {
|
||||
|
||||
@@ -87,11 +87,8 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
// Special handling for Arch Linux without /etc/pam.d/remote
|
||||
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
|
||||
return loginPath, []string{"-f", username, "-p"}, nil
|
||||
}
|
||||
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||
p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String())
|
||||
return p, a, nil
|
||||
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
|
||||
default:
|
||||
@@ -99,7 +96,37 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
|
||||
}
|
||||
}
|
||||
|
||||
// fileExists checks if a file exists (helper for login command logic)
|
||||
// getLinuxLoginCmd returns the login command for Linux systems.
|
||||
// Handles differences between util-linux and shadow-utils login implementations.
|
||||
func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) {
|
||||
// Special handling for Arch Linux without /etc/pam.d/remote
|
||||
var loginArgs []string
|
||||
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
|
||||
loginArgs = []string{"-f", username, "-p"}
|
||||
} else {
|
||||
loginArgs = []string{"-f", username, "-h", remoteIP, "-p"}
|
||||
}
|
||||
|
||||
// util-linux login requires setsid -c to create a new session and set the
|
||||
// controlling terminal. Without this, vhangup() kills the parent process.
|
||||
// See https://bugs.debian.org/1078023 for details.
|
||||
// TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec()
|
||||
// to avoid external setsid dependency.
|
||||
if !s.loginIsUtilLinux {
|
||||
return loginPath, loginArgs
|
||||
}
|
||||
|
||||
setsidPath, err := exec.LookPath("setsid")
|
||||
if err != nil {
|
||||
log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err)
|
||||
return loginPath, loginArgs
|
||||
}
|
||||
|
||||
args := append([]string{"-w", "-c", loginPath}, loginArgs...)
|
||||
return setsidPath, args
|
||||
}
|
||||
|
||||
// fileExists checks if a file exists
|
||||
func (s *Server) fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
|
||||
@@ -72,7 +72,8 @@ func IsSystemAccount(username string) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
return strings.HasSuffix(username, "$")
|
||||
}
|
||||
|
||||
// RegisterTestUserCleanup registers a test user for cleanup
|
||||
|
||||
115
client/ssh/testutil/user_helpers_test.go
Normal file
115
client/ssh/testutil/user_helpers_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestUserCurrentBehavior validates user.Current() behavior on Windows.
|
||||
// When running as SYSTEM on a domain-joined machine, user.Current() returns:
|
||||
// - Username: Computer account name (e.g., "DOMAIN\MACHINE$")
|
||||
// - SID: SYSTEM SID (S-1-5-18)
|
||||
func TestUserCurrentBehavior(t *testing.T) {
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("Windows-specific test")
|
||||
}
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
t.Logf("Current user - Username: %s, SID: %s", currentUser.Username, currentUser.Uid)
|
||||
|
||||
// When running as SYSTEM, validate expected behavior
|
||||
if currentUser.Uid == "S-1-5-18" {
|
||||
t.Run("SYSTEM_account_behavior", func(t *testing.T) {
|
||||
// SID must be S-1-5-18 for SYSTEM
|
||||
require.Equal(t, "S-1-5-18", currentUser.Uid,
|
||||
"SYSTEM account must have SID S-1-5-18")
|
||||
|
||||
// Username can be either "NT AUTHORITY\SYSTEM" (standalone)
|
||||
// or "DOMAIN\MACHINE$" (domain-joined)
|
||||
username := currentUser.Username
|
||||
isNTAuthority := strings.Contains(strings.ToUpper(username), "NT AUTHORITY")
|
||||
isComputerAccount := strings.HasSuffix(username, "$")
|
||||
|
||||
assert.True(t, isNTAuthority || isComputerAccount,
|
||||
"Username should be either 'NT AUTHORITY\\SYSTEM' or computer account (ending with $), got: %s",
|
||||
username)
|
||||
|
||||
if isComputerAccount {
|
||||
t.Logf("SYSTEM as computer account: %s", username)
|
||||
} else if isNTAuthority {
|
||||
t.Logf("SYSTEM as NT AUTHORITY\\SYSTEM")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Validate that IsSystemAccount correctly identifies system accounts
|
||||
t.Run("IsSystemAccount_validation", func(t *testing.T) {
|
||||
// Test with current user if it's a system account
|
||||
if currentUser.Uid == "S-1-5-18" || // SYSTEM
|
||||
currentUser.Uid == "S-1-5-19" || // LOCAL SERVICE
|
||||
currentUser.Uid == "S-1-5-20" { // NETWORK SERVICE
|
||||
|
||||
result := IsSystemAccount(currentUser.Username)
|
||||
assert.True(t, result,
|
||||
"IsSystemAccount should recognize system account: %s (SID: %s)",
|
||||
currentUser.Username, currentUser.Uid)
|
||||
}
|
||||
|
||||
// Test explicit cases
|
||||
testCases := []struct {
|
||||
username string
|
||||
expected bool
|
||||
reason string
|
||||
}{
|
||||
{"NT AUTHORITY\\SYSTEM", true, "NT AUTHORITY\\SYSTEM"},
|
||||
{"system", true, "system"},
|
||||
{"SYSTEM", true, "SYSTEM (case insensitive)"},
|
||||
{"NT AUTHORITY\\LOCAL SERVICE", true, "LOCAL SERVICE"},
|
||||
{"NT AUTHORITY\\NETWORK SERVICE", true, "NETWORK SERVICE"},
|
||||
{"DOMAIN\\MACHINE$", true, "computer account (ends with $)"},
|
||||
{"WORKGROUP\\WIN2K19-C2$", true, "computer account (ends with $)"},
|
||||
{"Administrator", false, "Administrator is not a system account"},
|
||||
{"alice", false, "regular user"},
|
||||
{"DOMAIN\\alice", false, "domain user"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.username, func(t *testing.T) {
|
||||
result := IsSystemAccount(tc.username)
|
||||
assert.Equal(t, tc.expected, result,
|
||||
"IsSystemAccount(%q) should be %v because: %s",
|
||||
tc.username, tc.expected, tc.reason)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestComputerAccountDetection validates computer account detection.
|
||||
func TestComputerAccountDetection(t *testing.T) {
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("Windows-specific test")
|
||||
}
|
||||
|
||||
computerAccounts := []string{
|
||||
"MACHINE$",
|
||||
"WIN2K19-C2$",
|
||||
"DOMAIN\\MACHINE$",
|
||||
"WORKGROUP\\SERVER$",
|
||||
"server.domain.com$",
|
||||
}
|
||||
|
||||
for _, account := range computerAccounts {
|
||||
t.Run(account, func(t *testing.T) {
|
||||
result := IsSystemAccount(account)
|
||||
assert.True(t, result,
|
||||
"Computer account %q should be recognized as system account", account)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/sleep"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||
"github.com/netbirdio/netbird/client/ui/event"
|
||||
@@ -209,10 +210,11 @@ var iconConnectedDot []byte
|
||||
var iconDisconnectedDot []byte
|
||||
|
||||
type serviceClient struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
addr string
|
||||
conn proto.DaemonServiceClient
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
addr string
|
||||
conn proto.DaemonServiceClient
|
||||
connLock sync.Mutex
|
||||
|
||||
eventHandler *eventHandler
|
||||
|
||||
@@ -1098,6 +1100,9 @@ func (s *serviceClient) onTrayReady() {
|
||||
|
||||
go s.eventManager.Start(s.ctx)
|
||||
go s.eventHandler.listen(s.ctx)
|
||||
|
||||
// Start sleep detection listener
|
||||
go s.startSleepListener()
|
||||
}
|
||||
|
||||
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
|
||||
@@ -1134,6 +1139,8 @@ func (s *serviceClient) onTrayExit() {
|
||||
|
||||
// getSrvClient connection to the service.
|
||||
func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonServiceClient, error) {
|
||||
s.connLock.Lock()
|
||||
defer s.connLock.Unlock()
|
||||
if s.conn != nil {
|
||||
return s.conn, nil
|
||||
}
|
||||
@@ -1156,6 +1163,62 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
|
||||
return s.conn, nil
|
||||
}
|
||||
|
||||
// startSleepListener initializes the sleep detection service and listens for sleep events
|
||||
func (s *serviceClient) startSleepListener() {
|
||||
sleepService, err := sleep.New()
|
||||
if err != nil {
|
||||
log.Warnf("%v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := sleepService.Register(s.handleSleepEvents); err != nil {
|
||||
log.Errorf("failed to start sleep detection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("sleep detection service initialized")
|
||||
|
||||
// Cleanup on context cancellation
|
||||
go func() {
|
||||
<-s.ctx.Done()
|
||||
log.Info("stopping sleep event listener")
|
||||
if err := sleepService.Deregister(); err != nil {
|
||||
log.Errorf("failed to deregister sleep detection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleSleepEvents sends a sleep notification to the daemon via gRPC
|
||||
func (s *serviceClient) handleSleepEvents(event sleep.EventType) {
|
||||
conn, err := s.getSrvClient(0)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get daemon client for sleep notification: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
req := &proto.OSLifecycleRequest{}
|
||||
|
||||
switch event {
|
||||
case sleep.EventTypeWakeUp:
|
||||
log.Infof("handle wakeup event: %v", event)
|
||||
req.Type = proto.OSLifecycleRequest_WAKEUP
|
||||
case sleep.EventTypeSleep:
|
||||
log.Infof("handle sleep event: %v", event)
|
||||
req.Type = proto.OSLifecycleRequest_SLEEP
|
||||
default:
|
||||
log.Infof("unknown event: %v", event)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.NotifyOSLifecycle(s.ctx, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to notify daemon about os lifecycle notification: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("successfully notified daemon about os lifecycle")
|
||||
}
|
||||
|
||||
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
||||
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
||||
if s.mSettings != nil {
|
||||
|
||||
@@ -28,7 +28,8 @@ func IsAnotherProcessRunning() (int32, bool, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) {
|
||||
runningProcessName := strings.ToLower(filepath.Base(runningProcessPath))
|
||||
if runningProcessName == processName && isProcessOwnedByCurrentUser(p) {
|
||||
return p.Pid, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,6 +45,10 @@ type CustomZone struct {
|
||||
Domain string
|
||||
// Records custom zone records
|
||||
Records []SimpleRecord
|
||||
// SearchDomainDisabled indicates whether to add match domains to a search domains list or not
|
||||
SearchDomainDisabled bool
|
||||
// SkipPTRProcess indicates whether a client should process PTR records from custom zones
|
||||
SkipPTRProcess bool
|
||||
}
|
||||
|
||||
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
|
||||
|
||||
@@ -60,14 +60,7 @@ func (hook ContextHook) Fire(entry *logrus.Entry) error {
|
||||
|
||||
entry.Data["context"] = source
|
||||
|
||||
switch source {
|
||||
case HTTPSource:
|
||||
addHTTPFields(entry)
|
||||
case GRPCSource:
|
||||
addGRPCFields(entry)
|
||||
case SystemSource:
|
||||
addSystemFields(entry)
|
||||
}
|
||||
addFields(entry)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -99,7 +92,7 @@ func (hook ContextHook) parseSrc(filePath string) string {
|
||||
return fmt.Sprintf("%s/%s", pkg, file)
|
||||
}
|
||||
|
||||
func addHTTPFields(entry *logrus.Entry) {
|
||||
func addFields(entry *logrus.Entry) {
|
||||
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
|
||||
entry.Data[context.RequestIDKey] = ctxReqID
|
||||
}
|
||||
@@ -109,30 +102,6 @@ func addHTTPFields(entry *logrus.Entry) {
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
}
|
||||
|
||||
func addGRPCFields(entry *logrus.Entry) {
|
||||
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
|
||||
entry.Data[context.RequestIDKey] = ctxReqID
|
||||
}
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
|
||||
entry.Data[context.PeerIDKey] = ctxDeviceID
|
||||
}
|
||||
}
|
||||
|
||||
func addSystemFields(entry *logrus.Entry) {
|
||||
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
|
||||
entry.Data[context.RequestIDKey] = ctxReqID
|
||||
}
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
|
||||
entry.Data[context.PeerIDKey] = ctxDeviceID
|
||||
}
|
||||
|
||||
24
go.mod
24
go.mod
@@ -1,6 +1,6 @@
|
||||
module github.com/netbirdio/netbird
|
||||
|
||||
go 1.23.1
|
||||
go 1.24.10
|
||||
|
||||
require (
|
||||
cunicu.li/go-rosenpass v0.4.0
|
||||
@@ -17,8 +17,8 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.41.0
|
||||
golang.org/x/sys v0.35.0
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
@@ -64,7 +64,7 @@ require (
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
@@ -105,12 +105,12 @@ require (
|
||||
go.uber.org/zap v1.27.0
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||
golang.org/x/mod v0.26.0
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.16.0
|
||||
golang.org/x/term v0.34.0
|
||||
golang.org/x/sync v0.18.0
|
||||
golang.org/x/term v0.37.0
|
||||
golang.org/x/time v0.12.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -251,9 +251,9 @@ require (
|
||||
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.24.0 // indirect
|
||||
golang.org/x/text v0.28.0 // indirect
|
||||
golang.org/x/tools v0.35.0 // indirect
|
||||
golang.org/x/image v0.33.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
golang.org/x/tools v0.39.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
|
||||
44
go.sum
44
go.sum
@@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
@@ -600,19 +600,19 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ=
|
||||
golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8=
|
||||
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
|
||||
golang.org/x/image v0.33.0/go.mod h1:DD3OsTYT9chzuzTQt+zMcOlBHgfoKQb1gry8p76Y1sc=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg=
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc=
|
||||
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab h1:Iqyc+2zr7aGyLuEadIm0KRJP0Wwt+fhlXLa51Fxf1+Q=
|
||||
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab/go.mod h1:Eq3Nh/5pFSWug2ohiudJ1iyU59SO78QFuh4qTTN++I0=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
@@ -622,8 +622,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -647,8 +647,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
@@ -665,8 +665,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -703,8 +703,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
@@ -717,8 +717,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
|
||||
golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -730,8 +730,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -749,8 +749,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
35
idp/cmd/env.go
Normal file
35
idp/cmd/env.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_IDP_
|
||||
func setFlagsFromEnvVars(cmd *cobra.Command) {
|
||||
flags := cmd.PersistentFlags()
|
||||
flags.VisitAll(func(f *pflag.Flag) {
|
||||
newEnvVar := flagNameToEnvVar(f.Name, "NB_IDP_")
|
||||
value, present := os.LookupEnv(newEnvVar)
|
||||
if !present {
|
||||
return
|
||||
}
|
||||
|
||||
err := flags.Set(f.Name, value)
|
||||
if err != nil {
|
||||
log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// flagNameToEnvVar converts flag name to environment var name adding a prefix,
|
||||
// replacing dashes and making all uppercase (e.g. data-dir is converted to NB_IDP_DATA_DIR)
|
||||
func flagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||
parsed := strings.ReplaceAll(cmdFlag, "-", "_")
|
||||
upper := strings.ToUpper(parsed)
|
||||
return prefix + upper
|
||||
}
|
||||
148
idp/cmd/root.go
Normal file
148
idp/cmd/root.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/oidcprovider"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// Config holds the IdP server configuration
|
||||
type Config struct {
|
||||
ListenPort int
|
||||
Issuer string
|
||||
DataDir string
|
||||
LogLevel string
|
||||
LogFile string
|
||||
DevMode bool
|
||||
DashboardRedirectURIs []string
|
||||
CLIRedirectURIs []string
|
||||
DashboardClientID string
|
||||
CLIClientID string
|
||||
}
|
||||
|
||||
var (
|
||||
config *Config
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "idp",
|
||||
Short: "NetBird Identity Provider",
|
||||
Long: "Embedded OIDC Identity Provider for NetBird",
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
RunE: execute,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
_ = util.InitLog("trace", util.LogConsole)
|
||||
config = &Config{}
|
||||
|
||||
rootCmd.PersistentFlags().IntVarP(&config.ListenPort, "port", "p", 33081, "port to listen on")
|
||||
rootCmd.PersistentFlags().StringVarP(&config.Issuer, "issuer", "i", "", "OIDC issuer URL (default: http://localhost:<port>)")
|
||||
rootCmd.PersistentFlags().StringVarP(&config.DataDir, "data-dir", "d", "/var/lib/netbird", "directory to store IdP data")
|
||||
rootCmd.PersistentFlags().StringVar(&config.LogLevel, "log-level", "info", "log level (trace, debug, info, warn, error)")
|
||||
rootCmd.PersistentFlags().StringVar(&config.LogFile, "log-file", "console", "log file path or 'console'")
|
||||
rootCmd.PersistentFlags().BoolVar(&config.DevMode, "dev-mode", false, "enable development mode (allows HTTP)")
|
||||
rootCmd.PersistentFlags().StringSliceVar(&config.DashboardRedirectURIs, "dashboard-redirect-uris", []string{
|
||||
"http://localhost:3000/callback",
|
||||
"http://localhost:3000/silent-callback",
|
||||
}, "allowed redirect URIs for dashboard client")
|
||||
rootCmd.PersistentFlags().StringSliceVar(&config.CLIRedirectURIs, "cli-redirect-uris", []string{
|
||||
"http://localhost:53000",
|
||||
"http://localhost:54000",
|
||||
}, "allowed redirect URIs for CLI client")
|
||||
rootCmd.PersistentFlags().StringVar(&config.DashboardClientID, "dashboard-client-id", "netbird-dashboard", "client ID for dashboard")
|
||||
rootCmd.PersistentFlags().StringVar(&config.CLIClientID, "cli-client-id", "netbird-client", "client ID for CLI")
|
||||
|
||||
// Add subcommands
|
||||
rootCmd.AddCommand(userCmd)
|
||||
|
||||
setFlagsFromEnvVars(rootCmd)
|
||||
}
|
||||
|
||||
// Execute runs the root command
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func execute(cmd *cobra.Command, args []string) error {
|
||||
err := util.InitLog(config.LogLevel, config.LogFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize log: %s", err)
|
||||
}
|
||||
|
||||
// Set default issuer if not provided
|
||||
issuer := config.Issuer
|
||||
if issuer == "" {
|
||||
issuer = fmt.Sprintf("http://localhost:%d", config.ListenPort)
|
||||
}
|
||||
|
||||
log.Infof("Starting NetBird Identity Provider")
|
||||
log.Infof(" Port: %d", config.ListenPort)
|
||||
log.Infof(" Issuer: %s", issuer)
|
||||
log.Infof(" Data directory: %s", config.DataDir)
|
||||
log.Infof(" Dev mode: %v", config.DevMode)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create provider config
|
||||
providerConfig := &oidcprovider.Config{
|
||||
Issuer: issuer,
|
||||
Port: config.ListenPort,
|
||||
DataDir: config.DataDir,
|
||||
DevMode: config.DevMode,
|
||||
}
|
||||
|
||||
// Create the provider
|
||||
provider, err := oidcprovider.NewProvider(ctx, providerConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create IdP: %w", err)
|
||||
}
|
||||
|
||||
// Ensure default clients exist
|
||||
if err := provider.EnsureDefaultClients(ctx, config.DashboardRedirectURIs, config.CLIRedirectURIs); err != nil {
|
||||
return fmt.Errorf("failed to create default clients: %w", err)
|
||||
}
|
||||
|
||||
// Start the provider
|
||||
if err := provider.Start(ctx); err != nil {
|
||||
return fmt.Errorf("failed to start IdP: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("IdP is running")
|
||||
log.Infof(" Discovery: %s/.well-known/openid-configuration", issuer)
|
||||
log.Infof(" Authorization: %s/authorize", issuer)
|
||||
log.Infof(" Token: %s/oauth/token", issuer)
|
||||
log.Infof(" Device authorization: %s/device_authorization", issuer)
|
||||
log.Infof(" JWKS: %s/keys", issuer)
|
||||
log.Infof(" Login: %s/login", issuer)
|
||||
log.Infof(" Device flow: %s/device", issuer)
|
||||
|
||||
// Wait for exit signal
|
||||
waitForExitSignal()
|
||||
|
||||
log.Infof("Shutting down IdP...")
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := provider.Stop(shutdownCtx); err != nil {
|
||||
return fmt.Errorf("failed to stop IdP: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("IdP stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForExitSignal() {
|
||||
osSigs := make(chan os.Signal, 1)
|
||||
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-osSigs
|
||||
}
|
||||
249
idp/cmd/user.go
Normal file
249
idp/cmd/user.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/oidcprovider"
|
||||
)
|
||||
|
||||
var userCmd = &cobra.Command{
|
||||
Use: "user",
|
||||
Short: "Manage IdP users",
|
||||
Long: "Commands for managing users in the embedded IdP",
|
||||
}
|
||||
|
||||
var userAddCmd = &cobra.Command{
|
||||
Use: "add",
|
||||
Short: "Add a new user",
|
||||
Long: "Add a new user to the embedded IdP",
|
||||
RunE: userAdd,
|
||||
}
|
||||
|
||||
var userListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List all users",
|
||||
Long: "List all users in the embedded IdP",
|
||||
RunE: userList,
|
||||
}
|
||||
|
||||
var userDeleteCmd = &cobra.Command{
|
||||
Use: "delete <username>",
|
||||
Short: "Delete a user",
|
||||
Long: "Delete a user from the embedded IdP",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: userDelete,
|
||||
}
|
||||
|
||||
var userPasswordCmd = &cobra.Command{
|
||||
Use: "password <username>",
|
||||
Short: "Change user password",
|
||||
Long: "Change password for a user in the embedded IdP",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: userChangePassword,
|
||||
}
|
||||
|
||||
// User add flags
|
||||
var (
|
||||
userUsername string
|
||||
userEmail string
|
||||
userFirstName string
|
||||
userLastName string
|
||||
userPassword string
|
||||
)
|
||||
|
||||
func init() {
|
||||
userAddCmd.Flags().StringVarP(&userUsername, "username", "u", "", "username (required)")
|
||||
userAddCmd.Flags().StringVarP(&userEmail, "email", "e", "", "email address (required)")
|
||||
userAddCmd.Flags().StringVarP(&userFirstName, "first-name", "f", "", "first name")
|
||||
userAddCmd.Flags().StringVarP(&userLastName, "last-name", "l", "", "last name")
|
||||
userAddCmd.Flags().StringVarP(&userPassword, "password", "p", "", "password (will prompt if not provided)")
|
||||
_ = userAddCmd.MarkFlagRequired("username")
|
||||
_ = userAddCmd.MarkFlagRequired("email")
|
||||
|
||||
userCmd.AddCommand(userAddCmd)
|
||||
userCmd.AddCommand(userListCmd)
|
||||
userCmd.AddCommand(userDeleteCmd)
|
||||
userCmd.AddCommand(userPasswordCmd)
|
||||
}
|
||||
|
||||
func getStore() (*oidcprovider.Store, error) {
|
||||
ctx := context.Background()
|
||||
store, err := oidcprovider.NewStore(ctx, config.DataDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open store: %w", err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func userAdd(cmd *cobra.Command, args []string) error {
|
||||
store, err := getStore()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
password := userPassword
|
||||
if password == "" {
|
||||
// Prompt for password
|
||||
fmt.Print("Enter password: ")
|
||||
bytePassword, err := term.ReadPassword(int(syscall.Stdin))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read password: %w", err)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
fmt.Print("Confirm password: ")
|
||||
byteConfirm, err := term.ReadPassword(int(syscall.Stdin))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read password confirmation: %w", err)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
if string(bytePassword) != string(byteConfirm) {
|
||||
return fmt.Errorf("passwords do not match")
|
||||
}
|
||||
password = string(bytePassword)
|
||||
}
|
||||
|
||||
if password == "" {
|
||||
return fmt.Errorf("password cannot be empty")
|
||||
}
|
||||
|
||||
user := &oidcprovider.User{
|
||||
Username: userUsername,
|
||||
Email: userEmail,
|
||||
FirstName: userFirstName,
|
||||
LastName: userLastName,
|
||||
Password: password,
|
||||
EmailVerified: true, // Mark as verified since admin is creating the user
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := store.CreateUser(ctx, user); err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("User '%s' created successfully (ID: %s)\n", userUsername, user.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func userList(cmd *cobra.Command, args []string) error {
|
||||
store, err := getStore()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
users, err := store.ListUsers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list users: %w", err)
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
fmt.Println("No users found")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(w, "ID\tUSERNAME\tEMAIL\tNAME\tVERIFIED\tCREATED")
|
||||
for _, user := range users {
|
||||
name := fmt.Sprintf("%s %s", user.FirstName, user.LastName)
|
||||
verified := "No"
|
||||
if user.EmailVerified {
|
||||
verified = "Yes"
|
||||
}
|
||||
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
user.ID,
|
||||
user.Username,
|
||||
user.Email,
|
||||
name,
|
||||
verified,
|
||||
user.CreatedAt.Format("2006-01-02 15:04"),
|
||||
)
|
||||
}
|
||||
w.Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func userDelete(cmd *cobra.Command, args []string) error {
|
||||
username := args[0]
|
||||
|
||||
store, err := getStore()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Find user by username
|
||||
user, err := store.GetUserByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("user '%s' not found", username)
|
||||
}
|
||||
|
||||
if err := store.DeleteUser(ctx, user.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("User '%s' deleted successfully\n", username)
|
||||
return nil
|
||||
}
|
||||
|
||||
func userChangePassword(cmd *cobra.Command, args []string) error {
|
||||
username := args[0]
|
||||
|
||||
store, err := getStore()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Find user by username
|
||||
user, err := store.GetUserByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("user '%s' not found", username)
|
||||
}
|
||||
|
||||
// Prompt for new password
|
||||
fmt.Print("Enter new password: ")
|
||||
bytePassword, err := term.ReadPassword(int(syscall.Stdin))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read password: %w", err)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
fmt.Print("Confirm new password: ")
|
||||
byteConfirm, err := term.ReadPassword(int(syscall.Stdin))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read password confirmation: %w", err)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
if string(bytePassword) != string(byteConfirm) {
|
||||
return fmt.Errorf("passwords do not match")
|
||||
}
|
||||
|
||||
password := string(bytePassword)
|
||||
if password == "" {
|
||||
return fmt.Errorf("password cannot be empty")
|
||||
}
|
||||
|
||||
if err := store.UpdateUserPassword(ctx, user.ID, password); err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Password updated for user '%s'\n", username)
|
||||
return nil
|
||||
}
|
||||
13
idp/main.go
Normal file
13
idp/main.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := cmd.Execute(); err != nil {
|
||||
log.Fatalf("failed to execute command: %v", err)
|
||||
}
|
||||
}
|
||||
249
idp/oidcprovider/client.go
Normal file
249
idp/oidcprovider/client.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package oidcprovider
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
// OIDCClient wraps the database Client model and implements op.Client interface
|
||||
type OIDCClient struct {
|
||||
client *Client
|
||||
loginURL func(string) string
|
||||
redirectURIs []string
|
||||
grantTypes []oidc.GrantType
|
||||
responseTypes []oidc.ResponseType
|
||||
}
|
||||
|
||||
// NewOIDCClient creates an OIDCClient from a database Client
|
||||
func NewOIDCClient(client *Client, loginURL func(string) string) *OIDCClient {
|
||||
return &OIDCClient{
|
||||
client: client,
|
||||
loginURL: loginURL,
|
||||
redirectURIs: ParseJSONArray(client.RedirectURIs),
|
||||
grantTypes: parseGrantTypes(client.GrantTypes),
|
||||
responseTypes: parseResponseTypes(client.ResponseTypes),
|
||||
}
|
||||
}
|
||||
|
||||
// GetID returns the client ID
|
||||
func (c *OIDCClient) GetID() string {
|
||||
return c.client.ID
|
||||
}
|
||||
|
||||
// RedirectURIs returns the registered redirect URIs
|
||||
func (c *OIDCClient) RedirectURIs() []string {
|
||||
return c.redirectURIs
|
||||
}
|
||||
|
||||
// PostLogoutRedirectURIs returns the registered post-logout redirect URIs
|
||||
func (c *OIDCClient) PostLogoutRedirectURIs() []string {
|
||||
return ParseJSONArray(c.client.PostLogoutURIs)
|
||||
}
|
||||
|
||||
// ApplicationType returns the application type (native, web, user_agent)
|
||||
func (c *OIDCClient) ApplicationType() op.ApplicationType {
|
||||
switch c.client.ApplicationType {
|
||||
case "native":
|
||||
return op.ApplicationTypeNative
|
||||
case "web":
|
||||
return op.ApplicationTypeWeb
|
||||
case "user_agent":
|
||||
return op.ApplicationTypeUserAgent
|
||||
default:
|
||||
return op.ApplicationTypeWeb
|
||||
}
|
||||
}
|
||||
|
||||
// AuthMethod returns the authentication method
|
||||
func (c *OIDCClient) AuthMethod() oidc.AuthMethod {
|
||||
switch c.client.AuthMethod {
|
||||
case "none":
|
||||
return oidc.AuthMethodNone
|
||||
case "client_secret_basic":
|
||||
return oidc.AuthMethodBasic
|
||||
case "client_secret_post":
|
||||
return oidc.AuthMethodPost
|
||||
case "private_key_jwt":
|
||||
return oidc.AuthMethodPrivateKeyJWT
|
||||
default:
|
||||
return oidc.AuthMethodNone
|
||||
}
|
||||
}
|
||||
|
||||
// ResponseTypes returns the allowed response types
|
||||
func (c *OIDCClient) ResponseTypes() []oidc.ResponseType {
|
||||
return c.responseTypes
|
||||
}
|
||||
|
||||
// GrantTypes returns the allowed grant types
|
||||
func (c *OIDCClient) GrantTypes() []oidc.GrantType {
|
||||
return c.grantTypes
|
||||
}
|
||||
|
||||
// LoginURL returns the login URL for this client
|
||||
func (c *OIDCClient) LoginURL(authRequestID string) string {
|
||||
if c.loginURL != nil {
|
||||
return c.loginURL(authRequestID)
|
||||
}
|
||||
return "/login?authRequestID=" + authRequestID
|
||||
}
|
||||
|
||||
// AccessTokenType returns the access token type
|
||||
func (c *OIDCClient) AccessTokenType() op.AccessTokenType {
|
||||
switch c.client.AccessTokenType {
|
||||
case "jwt":
|
||||
return op.AccessTokenTypeJWT
|
||||
default:
|
||||
return op.AccessTokenTypeBearer
|
||||
}
|
||||
}
|
||||
|
||||
// IDTokenLifetime returns the ID token lifetime
|
||||
func (c *OIDCClient) IDTokenLifetime() time.Duration {
|
||||
if c.client.IDTokenLifetime > 0 {
|
||||
return time.Duration(c.client.IDTokenLifetime) * time.Second
|
||||
}
|
||||
return time.Hour // default 1 hour
|
||||
}
|
||||
|
||||
// DevMode returns whether the client is in development mode
|
||||
func (c *OIDCClient) DevMode() bool {
|
||||
return c.client.DevMode
|
||||
}
|
||||
|
||||
// RestrictAdditionalIdTokenScopes returns any restricted scopes for ID tokens
|
||||
func (c *OIDCClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string {
|
||||
return func(scopes []string) []string {
|
||||
return scopes
|
||||
}
|
||||
}
|
||||
|
||||
// RestrictAdditionalAccessTokenScopes returns any restricted scopes for access tokens
|
||||
func (c *OIDCClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
|
||||
return func(scopes []string) []string {
|
||||
return scopes
|
||||
}
|
||||
}
|
||||
|
||||
// IsScopeAllowed checks if a scope is allowed for this client
|
||||
func (c *OIDCClient) IsScopeAllowed(scope string) bool {
|
||||
// Allow all standard OIDC scopes
|
||||
switch scope {
|
||||
case oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone, oidc.ScopeAddress, oidc.ScopeOfflineAccess:
|
||||
return true
|
||||
}
|
||||
return true // Allow custom scopes as well
|
||||
}
|
||||
|
||||
// IDTokenUserinfoClaimsAssertion returns whether userinfo claims should be included in ID token
|
||||
func (c *OIDCClient) IDTokenUserinfoClaimsAssertion() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ClockSkew returns the allowed clock skew for this client
|
||||
func (c *OIDCClient) ClockSkew() time.Duration {
|
||||
if c.client.ClockSkew > 0 {
|
||||
return time.Duration(c.client.ClockSkew) * time.Second
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Helper functions for parsing grant types and response types
|
||||
|
||||
func parseGrantTypes(jsonStr string) []oidc.GrantType {
|
||||
types := ParseJSONArray(jsonStr)
|
||||
if len(types) == 0 {
|
||||
// Default grant types
|
||||
return []oidc.GrantType{
|
||||
oidc.GrantTypeCode,
|
||||
oidc.GrantTypeRefreshToken,
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]oidc.GrantType, 0, len(types))
|
||||
for _, t := range types {
|
||||
switch t {
|
||||
case "authorization_code":
|
||||
result = append(result, oidc.GrantTypeCode)
|
||||
case "refresh_token":
|
||||
result = append(result, oidc.GrantTypeRefreshToken)
|
||||
case "client_credentials":
|
||||
result = append(result, oidc.GrantTypeClientCredentials)
|
||||
case "urn:ietf:params:oauth:grant-type:device_code":
|
||||
result = append(result, oidc.GrantTypeDeviceCode)
|
||||
case "urn:ietf:params:oauth:grant-type:token-exchange":
|
||||
result = append(result, oidc.GrantTypeTokenExchange)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseResponseTypes(jsonStr string) []oidc.ResponseType {
|
||||
types := ParseJSONArray(jsonStr)
|
||||
if len(types) == 0 {
|
||||
// Default response types
|
||||
return []oidc.ResponseType{oidc.ResponseTypeCode}
|
||||
}
|
||||
|
||||
result := make([]oidc.ResponseType, 0, len(types))
|
||||
for _, t := range types {
|
||||
switch t {
|
||||
case "code":
|
||||
result = append(result, oidc.ResponseTypeCode)
|
||||
case "id_token":
|
||||
result = append(result, oidc.ResponseTypeIDToken)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// CreateNativeClient creates a native client configuration (for CLI/mobile apps with PKCE)
|
||||
func CreateNativeClient(id, name string, redirectURIs []string) *Client {
|
||||
return &Client{
|
||||
ID: id,
|
||||
Name: name,
|
||||
RedirectURIs: ToJSONArray(redirectURIs),
|
||||
ApplicationType: "native",
|
||||
AuthMethod: "none", // Public client
|
||||
ResponseTypes: ToJSONArray([]string{"code"}),
|
||||
GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"}),
|
||||
AccessTokenType: "bearer",
|
||||
DevMode: true,
|
||||
IDTokenLifetime: 3600,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateWebClient creates a web client configuration (for SPAs/web apps)
|
||||
func CreateWebClient(id, secret, name string, redirectURIs []string) *Client {
|
||||
return &Client{
|
||||
ID: id,
|
||||
Secret: secret,
|
||||
Name: name,
|
||||
RedirectURIs: ToJSONArray(redirectURIs),
|
||||
ApplicationType: "web",
|
||||
AuthMethod: "client_secret_basic",
|
||||
ResponseTypes: ToJSONArray([]string{"code"}),
|
||||
GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token"}),
|
||||
AccessTokenType: "bearer",
|
||||
DevMode: false,
|
||||
IDTokenLifetime: 3600,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSPAClient creates a Single Page Application client configuration (public client for SPAs)
|
||||
func CreateSPAClient(id, name string, redirectURIs []string) *Client {
|
||||
return &Client{
|
||||
ID: id,
|
||||
Name: name,
|
||||
RedirectURIs: ToJSONArray(redirectURIs),
|
||||
ApplicationType: "user_agent",
|
||||
AuthMethod: "none", // Public client for SPA
|
||||
ResponseTypes: ToJSONArray([]string{"code"}),
|
||||
GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token"}),
|
||||
AccessTokenType: "bearer",
|
||||
DevMode: true,
|
||||
IDTokenLifetime: 3600,
|
||||
}
|
||||
}
|
||||
220
idp/oidcprovider/device.go
Normal file
220
idp/oidcprovider/device.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package oidcprovider
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/securecookie"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DeviceHandler handles the device authorization flow
|
||||
type DeviceHandler struct {
|
||||
storage *OIDCStorage
|
||||
tmpl *template.Template
|
||||
secureCookie *securecookie.SecureCookie
|
||||
}
|
||||
|
||||
// NewDeviceHandler creates a new device handler
|
||||
func NewDeviceHandler(storage *OIDCStorage) (*DeviceHandler, error) {
|
||||
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate secure cookie keys
|
||||
hashKey := securecookie.GenerateRandomKey(32)
|
||||
blockKey := securecookie.GenerateRandomKey(32)
|
||||
|
||||
return &DeviceHandler{
|
||||
storage: storage,
|
||||
tmpl: tmpl,
|
||||
secureCookie: securecookie.New(hashKey, blockKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Router returns the device flow router
|
||||
func (h *DeviceHandler) Router() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/", h.userCodePage)
|
||||
r.Post("/login", h.handleLogin)
|
||||
r.Post("/confirm", h.handleConfirm)
|
||||
return r
|
||||
}
|
||||
|
||||
// userCodePage displays the user code entry form
|
||||
func (h *DeviceHandler) userCodePage(w http.ResponseWriter, r *http.Request) {
|
||||
userCode := r.URL.Query().Get("user_code")
|
||||
|
||||
data := map[string]interface{}{
|
||||
"UserCode": userCode,
|
||||
"Error": "",
|
||||
"Step": "code", // code, login, or confirm
|
||||
}
|
||||
|
||||
if userCode != "" {
|
||||
// Verify the user code exists
|
||||
_, err := h.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode)
|
||||
if err != nil {
|
||||
data["Error"] = "Invalid or expired user code"
|
||||
data["UserCode"] = ""
|
||||
} else {
|
||||
data["Step"] = "login"
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.tmpl.ExecuteTemplate(w, "device.html", data); err != nil {
|
||||
log.Errorf("failed to render device template: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// handleLogin processes the login form on the device flow
|
||||
func (h *DeviceHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userCode := r.FormValue("user_code")
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
|
||||
data := map[string]interface{}{
|
||||
"UserCode": userCode,
|
||||
"Error": "",
|
||||
"Step": "login",
|
||||
}
|
||||
|
||||
if userCode == "" || username == "" || password == "" {
|
||||
data["Error"] = "Please fill in all fields"
|
||||
h.tmpl.ExecuteTemplate(w, "device.html", data)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate credentials
|
||||
userID, err := h.storage.CheckUsernamePasswordSimple(username, password)
|
||||
if err != nil {
|
||||
log.Warnf("device login failed for user %s: %v", username, err)
|
||||
data["Error"] = "Invalid username or password"
|
||||
h.tmpl.ExecuteTemplate(w, "device.html", data)
|
||||
return
|
||||
}
|
||||
|
||||
// Get device authorization info
|
||||
authState, err := h.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode)
|
||||
if err != nil {
|
||||
data["Error"] = "Invalid or expired user code"
|
||||
data["Step"] = "code"
|
||||
data["UserCode"] = ""
|
||||
h.tmpl.ExecuteTemplate(w, "device.html", data)
|
||||
return
|
||||
}
|
||||
|
||||
// Set secure cookie with user info for confirmation step
|
||||
cookieValue := map[string]string{
|
||||
"user_code": userCode,
|
||||
"user_id": userID,
|
||||
}
|
||||
|
||||
encoded, err := h.secureCookie.Encode("device_auth", cookieValue)
|
||||
if err != nil {
|
||||
log.Errorf("failed to encode cookie: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "device_auth",
|
||||
Value: encoded,
|
||||
Path: "/device",
|
||||
HttpOnly: true,
|
||||
Secure: r.TLS != nil,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
// Show confirmation page
|
||||
data["Step"] = "confirm"
|
||||
data["ClientID"] = authState.ClientID
|
||||
data["Scopes"] = authState.Scopes
|
||||
data["UserID"] = userID
|
||||
|
||||
h.tmpl.ExecuteTemplate(w, "device.html", data)
|
||||
}
|
||||
|
||||
// handleConfirm processes the authorization decision
|
||||
func (h *DeviceHandler) handleConfirm(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get values from cookie
|
||||
cookie, err := r.Cookie("device_auth")
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/device", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
var cookieValue map[string]string
|
||||
if err := h.secureCookie.Decode("device_auth", cookie.Value, &cookieValue); err != nil {
|
||||
http.Redirect(w, r, "/device", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
userCode := cookieValue["user_code"]
|
||||
userID := cookieValue["user_id"]
|
||||
action := r.FormValue("action")
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Step": "result",
|
||||
}
|
||||
|
||||
// Clear the cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "device_auth",
|
||||
Value: "",
|
||||
Path: "/device",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
if action == "allow" {
|
||||
if err := h.storage.CompleteDeviceAuthorization(r.Context(), userCode, userID); err != nil {
|
||||
log.Errorf("failed to complete device authorization: %v", err)
|
||||
data["Error"] = "Failed to authorize device"
|
||||
} else {
|
||||
data["Success"] = true
|
||||
data["Message"] = "Device authorized successfully! You can now close this window."
|
||||
}
|
||||
} else {
|
||||
if err := h.storage.DenyDeviceAuthorization(r.Context(), userCode); err != nil {
|
||||
log.Errorf("failed to deny device authorization: %v", err)
|
||||
}
|
||||
data["Success"] = false
|
||||
data["Message"] = "Authorization denied. You can close this window."
|
||||
}
|
||||
|
||||
h.tmpl.ExecuteTemplate(w, "device.html", data)
|
||||
}
|
||||
|
||||
// GenerateUserCode generates a user-friendly code for device flow
|
||||
func GenerateUserCode() string {
|
||||
// Generate a base20 code (BCDFGHJKLMNPQRSTVWXZ - no vowels to avoid words)
|
||||
chars := "BCDFGHJKLMNPQRSTVWXZ"
|
||||
b := securecookie.GenerateRandomKey(8)
|
||||
result := make([]byte, 8)
|
||||
for i := range result {
|
||||
result[i] = chars[int(b[i])%len(chars)]
|
||||
}
|
||||
// Format as XXXX-XXXX
|
||||
return string(result[:4]) + "-" + string(result[4:])
|
||||
}
|
||||
|
||||
// GenerateDeviceCode generates a secure device code
|
||||
func GenerateDeviceCode() string {
|
||||
b := securecookie.GenerateRandomKey(32)
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
105
idp/oidcprovider/login.go
Normal file
105
idp/oidcprovider/login.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package oidcprovider
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
//go:embed templates/*.html
|
||||
var templateFS embed.FS
|
||||
|
||||
// LoginHandler handles the login flow
|
||||
type LoginHandler struct {
|
||||
storage *OIDCStorage
|
||||
callback func(string) string
|
||||
tmpl *template.Template
|
||||
}
|
||||
|
||||
// NewLoginHandler creates a new login handler
|
||||
func NewLoginHandler(storage *OIDCStorage, callback func(string) string) (*LoginHandler, error) {
|
||||
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LoginHandler{
|
||||
storage: storage,
|
||||
callback: callback,
|
||||
tmpl: tmpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Router returns the login router
|
||||
func (h *LoginHandler) Router() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/", h.loginPage)
|
||||
r.Post("/", h.handleLogin)
|
||||
return r
|
||||
}
|
||||
|
||||
// loginPage displays the login form
|
||||
func (h *LoginHandler) loginPage(w http.ResponseWriter, r *http.Request) {
|
||||
authRequestID := r.URL.Query().Get("authRequestID")
|
||||
if authRequestID == "" {
|
||||
http.Error(w, "missing auth request ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"AuthRequestID": authRequestID,
|
||||
"Error": "",
|
||||
}
|
||||
|
||||
if err := h.tmpl.ExecuteTemplate(w, "login.html", data); err != nil {
|
||||
log.Errorf("failed to render login template: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// handleLogin processes the login form submission
|
||||
func (h *LoginHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
authRequestID := r.FormValue("authRequestID")
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
|
||||
if authRequestID == "" || username == "" || password == "" {
|
||||
data := map[string]interface{}{
|
||||
"AuthRequestID": authRequestID,
|
||||
"Error": "Please fill in all fields",
|
||||
}
|
||||
h.tmpl.ExecuteTemplate(w, "login.html", data)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate credentials and get user ID
|
||||
userID, err := h.storage.CheckUsernamePasswordSimple(username, password)
|
||||
if err != nil {
|
||||
log.Warnf("login failed for user %s: %v", username, err)
|
||||
data := map[string]interface{}{
|
||||
"AuthRequestID": authRequestID,
|
||||
"Error": "Invalid username or password",
|
||||
}
|
||||
h.tmpl.ExecuteTemplate(w, "login.html", data)
|
||||
return
|
||||
}
|
||||
|
||||
// Complete the auth request
|
||||
if err := h.storage.CompleteAuthRequest(r.Context(), authRequestID, userID); err != nil {
|
||||
log.Errorf("failed to complete auth request: %v", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to callback
|
||||
callbackURL := h.callback(authRequestID)
|
||||
http.Redirect(w, r, callbackURL, http.StatusFound)
|
||||
}
|
||||
136
idp/oidcprovider/models.go
Normal file
136
idp/oidcprovider/models.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package oidcprovider
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// User represents an OIDC user stored in the database
|
||||
type User struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
Username string `gorm:"uniqueIndex;not null"`
|
||||
Password string `gorm:"not null"` // bcrypt hashed
|
||||
Email string
|
||||
EmailVerified bool
|
||||
FirstName string
|
||||
LastName string
|
||||
Phone string
|
||||
PhoneVerified bool
|
||||
PreferredLanguage string // language tag string
|
||||
IsAdmin bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// GetPreferredLanguage returns the user's preferred language as a language.Tag
|
||||
func (u *User) GetPreferredLanguage() language.Tag {
|
||||
if u.PreferredLanguage == "" {
|
||||
return language.English
|
||||
}
|
||||
tag, err := language.Parse(u.PreferredLanguage)
|
||||
if err != nil {
|
||||
return language.English
|
||||
}
|
||||
return tag
|
||||
}
|
||||
|
||||
// Client represents an OIDC client (application) stored in the database
|
||||
type Client struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
Secret string // bcrypt hashed, empty for public clients
|
||||
Name string
|
||||
RedirectURIs string // JSON array of redirect URIs
|
||||
PostLogoutURIs string // JSON array of post-logout redirect URIs
|
||||
ApplicationType string // native, web, user_agent
|
||||
AuthMethod string // none, client_secret_basic, client_secret_post, private_key_jwt
|
||||
ResponseTypes string // JSON array: code, id_token, token
|
||||
GrantTypes string // JSON array: authorization_code, refresh_token, client_credentials, urn:ietf:params:oauth:grant-type:device_code
|
||||
AccessTokenType string // bearer or jwt
|
||||
DevMode bool // allows non-HTTPS redirect URIs
|
||||
IDTokenLifetime int64 // in seconds, default 3600 (1 hour)
|
||||
ClockSkew int64 // in seconds, allowed clock skew
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// AuthRequest represents an ongoing authorization request
|
||||
type AuthRequest struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
ClientID string `gorm:"index"`
|
||||
Scopes string // JSON array of scopes
|
||||
RedirectURI string
|
||||
State string
|
||||
Nonce string
|
||||
ResponseType string
|
||||
ResponseMode string
|
||||
CodeChallenge string
|
||||
CodeMethod string // S256 or plain
|
||||
UserID string // set after user authentication
|
||||
Done bool // true when user has authenticated
|
||||
AuthTime time.Time
|
||||
CreatedAt time.Time
|
||||
MaxAge int64 // max authentication age in seconds
|
||||
Prompt string // none, login, consent, select_account
|
||||
UILocales string // space-separated list of locales
|
||||
LoginHint string
|
||||
ACRValues string // space-separated list of ACR values
|
||||
}
|
||||
|
||||
// AuthCode represents an authorization code
|
||||
type AuthCode struct {
|
||||
Code string `gorm:"primaryKey"`
|
||||
AuthRequestID string `gorm:"index"`
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// AccessToken represents an access token
|
||||
type AccessToken struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
ApplicationID string `gorm:"index"`
|
||||
Subject string `gorm:"index"`
|
||||
Audience string // JSON array
|
||||
Scopes string // JSON array
|
||||
Expiration time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// RefreshToken represents a refresh token
|
||||
type RefreshToken struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
Token string `gorm:"uniqueIndex"`
|
||||
AuthRequestID string
|
||||
ApplicationID string `gorm:"index"`
|
||||
Subject string `gorm:"index"`
|
||||
Audience string // JSON array
|
||||
Scopes string // JSON array
|
||||
AMR string // JSON array of authentication methods
|
||||
AuthTime time.Time
|
||||
Expiration time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// DeviceAuth represents a device authorization request
|
||||
type DeviceAuth struct {
|
||||
DeviceCode string `gorm:"primaryKey"`
|
||||
UserCode string `gorm:"uniqueIndex"`
|
||||
ClientID string `gorm:"index"`
|
||||
Scopes string // JSON array
|
||||
Subject string // set after user authentication
|
||||
Audience string // JSON array
|
||||
Done bool // true when user has authorized
|
||||
Denied bool // true when user has denied
|
||||
Expiration time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// SigningKey represents a signing key for JWTs
|
||||
type SigningKey struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
Algorithm string // RS256
|
||||
PrivateKey []byte // PEM encoded
|
||||
PublicKey []byte // PEM encoded
|
||||
CreatedAt time.Time
|
||||
Active bool
|
||||
}
|
||||
662
idp/oidcprovider/oidc_storage.go
Normal file
662
idp/oidcprovider/oidc_storage.go
Normal file
@@ -0,0 +1,662 @@
|
||||
package oidcprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ErrInvalidRefreshToken is returned when a token is not a valid refresh token
|
||||
var ErrInvalidRefreshToken = errors.New("invalid refresh token")
|
||||
|
||||
// OIDCStorage implements op.Storage interface for the OIDC provider
|
||||
type OIDCStorage struct {
|
||||
store *Store
|
||||
issuer string
|
||||
loginURL func(string) string
|
||||
}
|
||||
|
||||
// NewOIDCStorage creates a new OIDCStorage
|
||||
func NewOIDCStorage(store *Store, issuer string) *OIDCStorage {
|
||||
return &OIDCStorage{
|
||||
store: store,
|
||||
issuer: issuer,
|
||||
}
|
||||
}
|
||||
|
||||
// SetLoginURL sets the login URL generator function
|
||||
func (s *OIDCStorage) SetLoginURL(fn func(string) string) {
|
||||
s.loginURL = fn
|
||||
}
|
||||
|
||||
// Health checks if the storage is healthy
|
||||
func (s *OIDCStorage) Health(ctx context.Context) error {
|
||||
sqlDB, err := s.store.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.PingContext(ctx)
|
||||
}
|
||||
|
||||
// CreateAuthRequest creates and stores a new authorization request
|
||||
func (s *OIDCStorage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
|
||||
req := &AuthRequest{
|
||||
ID: uuid.New().String(),
|
||||
ClientID: authReq.ClientID,
|
||||
Scopes: ToJSONArray(authReq.Scopes),
|
||||
RedirectURI: authReq.RedirectURI,
|
||||
State: authReq.State,
|
||||
Nonce: authReq.Nonce,
|
||||
ResponseType: string(authReq.ResponseType),
|
||||
ResponseMode: string(authReq.ResponseMode),
|
||||
CodeChallenge: authReq.CodeChallenge,
|
||||
CodeMethod: string(authReq.CodeChallengeMethod),
|
||||
UserID: userID,
|
||||
Done: userID != "",
|
||||
CreatedAt: time.Now(),
|
||||
Prompt: spaceSeparated(authReq.Prompt),
|
||||
UILocales: authReq.UILocales.String(),
|
||||
LoginHint: authReq.LoginHint,
|
||||
ACRValues: spaceSeparated(authReq.ACRValues),
|
||||
}
|
||||
|
||||
if authReq.MaxAge != nil {
|
||||
req.MaxAge = int64(*authReq.MaxAge)
|
||||
}
|
||||
|
||||
if userID != "" {
|
||||
req.AuthTime = time.Now()
|
||||
}
|
||||
|
||||
if err := s.store.SaveAuthRequest(ctx, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &OIDCAuthRequest{req: req, storage: s}, nil
|
||||
}
|
||||
|
||||
// AuthRequestByID retrieves an authorization request by ID
|
||||
func (s *OIDCStorage) AuthRequestByID(ctx context.Context, id string) (op.AuthRequest, error) {
|
||||
req, err := s.store.GetAuthRequestByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("auth request not found: %s", id)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &OIDCAuthRequest{req: req, storage: s}, nil
|
||||
}
|
||||
|
||||
// AuthRequestByCode retrieves an authorization request by code
|
||||
func (s *OIDCStorage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) {
|
||||
authCode, err := s.store.GetAuthCodeByCode(ctx, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("auth code not found: %s", code)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if time.Now().After(authCode.ExpiresAt) {
|
||||
_ = s.store.DeleteAuthCode(ctx, code)
|
||||
return nil, errors.New("auth code expired")
|
||||
}
|
||||
|
||||
req, err := s.store.GetAuthRequestByID(ctx, authCode.AuthRequestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &OIDCAuthRequest{req: req, storage: s}, nil
|
||||
}
|
||||
|
||||
// SaveAuthCode saves an authorization code linked to an auth request
|
||||
func (s *OIDCStorage) SaveAuthCode(ctx context.Context, id, code string) error {
|
||||
authCode := &AuthCode{
|
||||
Code: code,
|
||||
AuthRequestID: id,
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
}
|
||||
return s.store.SaveAuthCode(ctx, authCode)
|
||||
}
|
||||
|
||||
// DeleteAuthRequest deletes an authorization request
|
||||
func (s *OIDCStorage) DeleteAuthRequest(ctx context.Context, id string) error {
|
||||
return s.store.DeleteAuthRequest(ctx, id)
|
||||
}
|
||||
|
||||
// CreateAccessToken creates and stores an access token
|
||||
func (s *OIDCStorage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (string, time.Time, error) {
|
||||
tokenID := uuid.New().String()
|
||||
expiration := time.Now().Add(5 * time.Minute)
|
||||
|
||||
// Get client ID from the request if possible
|
||||
var clientID string
|
||||
if authReq, ok := request.(op.AuthRequest); ok {
|
||||
clientID = authReq.GetClientID()
|
||||
} else if refreshReq, ok := request.(op.RefreshTokenRequest); ok {
|
||||
clientID = refreshReq.GetClientID()
|
||||
}
|
||||
|
||||
token := &AccessToken{
|
||||
ID: tokenID,
|
||||
ApplicationID: clientID,
|
||||
Subject: request.GetSubject(),
|
||||
Audience: ToJSONArray(request.GetAudience()),
|
||||
Scopes: ToJSONArray(request.GetScopes()),
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
if err := s.store.SaveAccessToken(ctx, token); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
return tokenID, expiration, nil
|
||||
}
|
||||
|
||||
// CreateAccessAndRefreshTokens creates both access and refresh tokens
|
||||
func (s *OIDCStorage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
|
||||
// Delete old refresh token if provided
|
||||
if currentRefreshToken != "" {
|
||||
_ = s.store.DeleteRefreshTokenByToken(ctx, currentRefreshToken)
|
||||
}
|
||||
|
||||
// Create access token
|
||||
accessTokenID, expiration, err = s.CreateAccessToken(ctx, request)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
|
||||
// Get additional info from the request if possible
|
||||
var clientID string
|
||||
var authTime time.Time
|
||||
var amr []string
|
||||
|
||||
if authReq, ok := request.(op.AuthRequest); ok {
|
||||
clientID = authReq.GetClientID()
|
||||
authTime = authReq.GetAuthTime()
|
||||
amr = authReq.GetAMR()
|
||||
} else if refreshReq, ok := request.(op.RefreshTokenRequest); ok {
|
||||
clientID = refreshReq.GetClientID()
|
||||
authTime = refreshReq.GetAuthTime()
|
||||
amr = refreshReq.GetAMR()
|
||||
}
|
||||
|
||||
// Create refresh token
|
||||
refreshToken := &RefreshToken{
|
||||
ID: uuid.New().String(),
|
||||
Token: uuid.New().String(),
|
||||
ApplicationID: clientID,
|
||||
Subject: request.GetSubject(),
|
||||
Audience: ToJSONArray(request.GetAudience()),
|
||||
Scopes: ToJSONArray(request.GetScopes()),
|
||||
AuthTime: authTime,
|
||||
AMR: ToJSONArray(amr),
|
||||
Expiration: time.Now().Add(5 * time.Hour), // 5 hour refresh token lifetime
|
||||
}
|
||||
|
||||
if authReq, ok := request.(op.AuthRequest); ok {
|
||||
refreshToken.AuthRequestID = authReq.GetID()
|
||||
}
|
||||
|
||||
if err := s.store.SaveRefreshToken(ctx, refreshToken); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
|
||||
return accessTokenID, refreshToken.Token, expiration, nil
|
||||
}
|
||||
|
||||
// TokenRequestByRefreshToken retrieves token request info from refresh token
|
||||
func (s *OIDCStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
|
||||
token, err := s.store.GetRefreshToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("refresh token not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if time.Now().After(token.Expiration) {
|
||||
_ = s.store.DeleteRefreshTokenByToken(ctx, refreshToken)
|
||||
return nil, errors.New("refresh token expired")
|
||||
}
|
||||
|
||||
return &OIDCRefreshToken{token: token}, nil
|
||||
}
|
||||
|
||||
// TerminateSession terminates a user session
|
||||
func (s *OIDCStorage) TerminateSession(ctx context.Context, userID, clientID string) error {
|
||||
// For now, we don't track sessions separately
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeToken revokes a token
|
||||
func (s *OIDCStorage) RevokeToken(ctx context.Context, tokenOrID string, userID string, clientID string) *oidc.Error {
|
||||
// Try to delete as refresh token
|
||||
if err := s.store.DeleteRefreshTokenByToken(ctx, tokenOrID); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to delete as access token
|
||||
if err := s.store.DeleteAccessToken(ctx, tokenOrID); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil // Silently succeed even if token not found (per spec)
|
||||
}
|
||||
|
||||
// GetRefreshTokenInfo returns info about a refresh token
|
||||
func (s *OIDCStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) {
|
||||
refreshToken, err := s.store.GetRefreshToken(ctx, token)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", "", ErrInvalidRefreshToken
|
||||
}
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if refreshToken.ApplicationID != clientID {
|
||||
return "", "", ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
return refreshToken.Subject, refreshToken.ID, nil
|
||||
}
|
||||
|
||||
// GetClientByClientID retrieves a client by ID
|
||||
func (s *OIDCStorage) GetClientByClientID(ctx context.Context, clientID string) (op.Client, error) {
|
||||
client, err := s.store.GetClientByID(ctx, clientID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("client not found: %s", clientID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return NewOIDCClient(client, s.loginURL), nil
|
||||
}
|
||||
|
||||
// AuthorizeClientIDSecret validates client credentials
|
||||
func (s *OIDCStorage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error {
|
||||
_, err := s.store.ValidateClientSecret(ctx, clientID, clientSecret)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetUserinfoFromScopes sets userinfo claims based on scopes
|
||||
func (s *OIDCStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
|
||||
return s.setUserinfo(ctx, userinfo, userID, scopes)
|
||||
}
|
||||
|
||||
// SetUserinfoFromToken sets userinfo claims from an access token
|
||||
func (s *OIDCStorage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error {
|
||||
token, err := s.store.GetAccessTokenByID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.setUserinfo(ctx, userinfo, token.Subject, ParseJSONArray(token.Scopes))
|
||||
}
|
||||
|
||||
// setUserinfo populates userinfo based on user data and scopes
|
||||
func (s *OIDCStorage) setUserinfo(ctx context.Context, userinfo *oidc.UserInfo, userID string, scopes []string) error {
|
||||
user, err := s.store.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, scope := range scopes {
|
||||
switch scope {
|
||||
case oidc.ScopeOpenID:
|
||||
userinfo.Subject = user.ID
|
||||
case oidc.ScopeProfile:
|
||||
userinfo.Name = fmt.Sprintf("%s %s", user.FirstName, user.LastName)
|
||||
userinfo.GivenName = user.FirstName
|
||||
userinfo.FamilyName = user.LastName
|
||||
userinfo.PreferredUsername = user.Username
|
||||
userinfo.Locale = oidc.NewLocale(user.GetPreferredLanguage())
|
||||
case oidc.ScopeEmail:
|
||||
userinfo.Email = user.Email
|
||||
userinfo.EmailVerified = oidc.Bool(user.EmailVerified)
|
||||
case oidc.ScopePhone:
|
||||
userinfo.PhoneNumber = user.Phone
|
||||
userinfo.PhoneNumberVerified = user.PhoneVerified
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetIntrospectionFromToken sets introspection response from token
|
||||
func (s *OIDCStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
|
||||
token, err := s.store.GetAccessTokenByID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
introspection.Active = true
|
||||
introspection.Subject = token.Subject
|
||||
introspection.ClientID = token.ApplicationID
|
||||
introspection.Scope = ParseJSONArray(token.Scopes)
|
||||
introspection.Expiration = oidc.FromTime(token.Expiration)
|
||||
introspection.IssuedAt = oidc.FromTime(token.CreatedAt)
|
||||
introspection.Audience = ParseJSONArray(token.Audience)
|
||||
introspection.Issuer = s.issuer
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPrivateClaimsFromScopes returns additional claims based on scopes
|
||||
func (s *OIDCStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetKeyByIDAndClientID retrieves a key by ID for a client
|
||||
func (s *OIDCStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// ValidateJWTProfileScopes validates scopes for JWT profile grant
|
||||
func (s *OIDCStorage) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) {
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
// SigningKey returns the active signing key for token signing
|
||||
func (s *OIDCStorage) SigningKey(ctx context.Context) (op.SigningKey, error) {
|
||||
key, err := s.store.GetSigningKey(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(key.PrivateKey)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode private key PEM")
|
||||
}
|
||||
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
|
||||
return &signingKey{
|
||||
id: key.ID,
|
||||
algorithm: jose.RS256,
|
||||
privateKey: privateKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SignatureAlgorithms returns supported signature algorithms
|
||||
func (s *OIDCStorage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) {
|
||||
return []jose.SignatureAlgorithm{jose.RS256}, nil
|
||||
}
|
||||
|
||||
// KeySet returns the public key set for token verification
|
||||
func (s *OIDCStorage) KeySet(ctx context.Context) ([]op.Key, error) {
|
||||
key, err := s.store.GetSigningKey(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(key.PublicKey)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode public key PEM")
|
||||
}
|
||||
|
||||
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
rsaKey, ok := publicKey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("public key is not RSA")
|
||||
}
|
||||
|
||||
return []op.Key{
|
||||
&publicKeyInfo{
|
||||
id: key.ID,
|
||||
algorithm: jose.RS256,
|
||||
publicKey: rsaKey,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Device Authorization Flow methods
|
||||
|
||||
// StoreDeviceAuthorization stores a device authorization request
|
||||
func (s *OIDCStorage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error {
|
||||
auth := &DeviceAuth{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
ClientID: clientID,
|
||||
Scopes: ToJSONArray(scopes),
|
||||
Expiration: expires,
|
||||
}
|
||||
return s.store.SaveDeviceAuth(ctx, auth)
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationState retrieves the state of a device authorization
|
||||
func (s *OIDCStorage) GetDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string) (*op.DeviceAuthorizationState, error) {
|
||||
auth, err := s.store.GetDeviceAuthByDeviceCode(ctx, deviceCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("device authorization not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if auth.ClientID != clientID {
|
||||
return nil, errors.New("client ID mismatch")
|
||||
}
|
||||
|
||||
if time.Now().After(auth.Expiration) {
|
||||
_ = s.store.DeleteDeviceAuth(ctx, deviceCode)
|
||||
return &op.DeviceAuthorizationState{Expires: auth.Expiration}, nil
|
||||
}
|
||||
|
||||
state := &op.DeviceAuthorizationState{
|
||||
ClientID: auth.ClientID,
|
||||
Scopes: ParseJSONArray(auth.Scopes),
|
||||
Expires: auth.Expiration,
|
||||
}
|
||||
|
||||
if auth.Denied {
|
||||
state.Denied = true
|
||||
} else if auth.Done {
|
||||
state.Done = true
|
||||
state.Subject = auth.Subject
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationByUserCode retrieves device auth by user code
|
||||
func (s *OIDCStorage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) {
|
||||
auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("device authorization not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if time.Now().After(auth.Expiration) {
|
||||
return nil, errors.New("device authorization expired")
|
||||
}
|
||||
|
||||
return &op.DeviceAuthorizationState{
|
||||
ClientID: auth.ClientID,
|
||||
Scopes: ParseJSONArray(auth.Scopes),
|
||||
Expires: auth.Expiration,
|
||||
Done: auth.Done,
|
||||
Denied: auth.Denied,
|
||||
Subject: auth.Subject,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CompleteDeviceAuthorization marks a device authorization as complete
|
||||
func (s *OIDCStorage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error {
|
||||
auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
auth.Done = true
|
||||
auth.Subject = subject
|
||||
return s.store.UpdateDeviceAuth(ctx, auth)
|
||||
}
|
||||
|
||||
// DenyDeviceAuthorization marks a device authorization as denied
|
||||
func (s *OIDCStorage) DenyDeviceAuthorization(ctx context.Context, userCode string) error {
|
||||
auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
auth.Denied = true
|
||||
return s.store.UpdateDeviceAuth(ctx, auth)
|
||||
}
|
||||
|
||||
// User authentication methods
|
||||
|
||||
// CheckUsernamePassword validates user credentials
|
||||
func (s *OIDCStorage) CheckUsernamePassword(username, password, authRequestID string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.store.ValidateUserPassword(ctx, username, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckUsernamePasswordSimple validates user credentials and returns the user ID
|
||||
func (s *OIDCStorage) CheckUsernamePasswordSimple(username, password string) (string, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := s.store.ValidateUserPassword(ctx, username, password)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return user.ID, nil
|
||||
}
|
||||
|
||||
// CompleteAuthRequest completes an auth request after user authentication
|
||||
func (s *OIDCStorage) CompleteAuthRequest(ctx context.Context, authRequestID, userID string) error {
|
||||
req, err := s.store.GetAuthRequestByID(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.UserID = userID
|
||||
req.Done = true
|
||||
req.AuthTime = time.Now()
|
||||
|
||||
return s.store.UpdateAuthRequest(ctx, req)
|
||||
}
|
||||
|
||||
// Helper types
|
||||
|
||||
// signingKey implements op.SigningKey
|
||||
type signingKey struct {
|
||||
id string
|
||||
algorithm jose.SignatureAlgorithm
|
||||
privateKey *rsa.PrivateKey
|
||||
}
|
||||
|
||||
func (k *signingKey) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return k.algorithm
|
||||
}
|
||||
|
||||
func (k *signingKey) Key() interface{} {
|
||||
return k.privateKey
|
||||
}
|
||||
|
||||
func (k *signingKey) ID() string {
|
||||
return k.id
|
||||
}
|
||||
|
||||
// publicKeyInfo implements op.Key
|
||||
type publicKeyInfo struct {
|
||||
id string
|
||||
algorithm jose.SignatureAlgorithm
|
||||
publicKey *rsa.PublicKey
|
||||
}
|
||||
|
||||
func (k *publicKeyInfo) ID() string {
|
||||
return k.id
|
||||
}
|
||||
|
||||
func (k *publicKeyInfo) Algorithm() jose.SignatureAlgorithm {
|
||||
return k.algorithm
|
||||
}
|
||||
|
||||
func (k *publicKeyInfo) Use() string {
|
||||
return "sig"
|
||||
}
|
||||
|
||||
func (k *publicKeyInfo) Key() interface{} {
|
||||
return k.publicKey
|
||||
}
|
||||
|
||||
// OIDCAuthRequest wraps AuthRequest for the op.AuthRequest interface
|
||||
type OIDCAuthRequest struct {
|
||||
req *AuthRequest
|
||||
storage *OIDCStorage
|
||||
}
|
||||
|
||||
func (r *OIDCAuthRequest) GetID() string { return r.req.ID }
|
||||
func (r *OIDCAuthRequest) GetACR() string { return "" }
|
||||
func (r *OIDCAuthRequest) GetAMR() []string { return []string{"pwd"} }
|
||||
func (r *OIDCAuthRequest) GetAudience() []string { return []string{r.req.ClientID} }
|
||||
func (r *OIDCAuthRequest) GetAuthTime() time.Time { return r.req.AuthTime }
|
||||
func (r *OIDCAuthRequest) GetClientID() string { return r.req.ClientID }
|
||||
func (r *OIDCAuthRequest) GetCodeChallenge() *oidc.CodeChallenge {
|
||||
if r.req.CodeChallenge == "" {
|
||||
return nil
|
||||
}
|
||||
return &oidc.CodeChallenge{
|
||||
Challenge: r.req.CodeChallenge,
|
||||
Method: oidc.CodeChallengeMethod(r.req.CodeMethod),
|
||||
}
|
||||
}
|
||||
func (r *OIDCAuthRequest) GetNonce() string { return r.req.Nonce }
|
||||
func (r *OIDCAuthRequest) GetRedirectURI() string { return r.req.RedirectURI }
|
||||
func (r *OIDCAuthRequest) GetResponseType() oidc.ResponseType {
|
||||
return oidc.ResponseType(r.req.ResponseType)
|
||||
}
|
||||
func (r *OIDCAuthRequest) GetResponseMode() oidc.ResponseMode {
|
||||
return oidc.ResponseMode(r.req.ResponseMode)
|
||||
}
|
||||
func (r *OIDCAuthRequest) GetScopes() []string { return ParseJSONArray(r.req.Scopes) }
|
||||
func (r *OIDCAuthRequest) GetState() string { return r.req.State }
|
||||
func (r *OIDCAuthRequest) GetSubject() string { return r.req.UserID }
|
||||
func (r *OIDCAuthRequest) Done() bool { return r.req.Done }
|
||||
|
||||
// OIDCRefreshToken wraps RefreshToken for the op.RefreshTokenRequest interface
|
||||
type OIDCRefreshToken struct {
|
||||
token *RefreshToken
|
||||
}
|
||||
|
||||
func (r *OIDCRefreshToken) GetAMR() []string { return ParseJSONArray(r.token.AMR) }
|
||||
func (r *OIDCRefreshToken) GetAudience() []string { return ParseJSONArray(r.token.Audience) }
|
||||
func (r *OIDCRefreshToken) GetAuthTime() time.Time { return r.token.AuthTime }
|
||||
func (r *OIDCRefreshToken) GetClientID() string { return r.token.ApplicationID }
|
||||
func (r *OIDCRefreshToken) GetScopes() []string { return ParseJSONArray(r.token.Scopes) }
|
||||
func (r *OIDCRefreshToken) GetSubject() string { return r.token.Subject }
|
||||
func (r *OIDCRefreshToken) SetCurrentScopes(scopes []string) {}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func spaceSeparated(items []string) string {
|
||||
return strings.Join(items, " ")
|
||||
}
|
||||
265
idp/oidcprovider/provider.go
Normal file
265
idp/oidcprovider/provider.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package oidcprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the OIDC provider
|
||||
type Config struct {
|
||||
// Issuer is the OIDC issuer URL (e.g., "https://idp.example.com")
|
||||
Issuer string
|
||||
// Port is the port to listen on
|
||||
Port int
|
||||
// DataDir is the directory to store OIDC data (SQLite database)
|
||||
DataDir string
|
||||
// DevMode enables development mode (allows HTTP, localhost)
|
||||
DevMode bool
|
||||
}
|
||||
|
||||
// Provider represents the embedded OIDC provider
|
||||
type Provider struct {
|
||||
config *Config
|
||||
store *Store
|
||||
storage *OIDCStorage
|
||||
provider op.OpenIDProvider
|
||||
router chi.Router
|
||||
httpServer *http.Server
|
||||
}
|
||||
|
||||
// NewProvider creates a new OIDC provider
|
||||
func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
// Create the SQLite store
|
||||
store, err := NewStore(ctx, config.DataDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OIDC store: %w", err)
|
||||
}
|
||||
|
||||
// Create the OIDC storage adapter
|
||||
storage := NewOIDCStorage(store, config.Issuer)
|
||||
|
||||
p := &Provider{
|
||||
config: config,
|
||||
store: store,
|
||||
storage: storage,
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Start starts the OIDC provider server
|
||||
func (p *Provider) Start(ctx context.Context) error {
|
||||
// Create the router
|
||||
router := chi.NewRouter()
|
||||
router.Use(middleware.Logger)
|
||||
router.Use(middleware.Recoverer)
|
||||
router.Use(middleware.RequestID)
|
||||
|
||||
// Create the OIDC provider
|
||||
key := sha256.Sum256([]byte(p.config.Issuer + "encryption-key"))
|
||||
|
||||
opConfig := &op.Config{
|
||||
CryptoKey: key,
|
||||
DefaultLogoutRedirectURI: "/logged-out",
|
||||
CodeMethodS256: true,
|
||||
AuthMethodPost: true,
|
||||
AuthMethodPrivateKeyJWT: true,
|
||||
GrantTypeRefreshToken: true,
|
||||
RequestObjectSupported: true,
|
||||
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||
Lifetime: 5 * time.Minute,
|
||||
PollInterval: 5 * time.Second,
|
||||
UserFormPath: "/device",
|
||||
UserCode: op.UserCodeBase20,
|
||||
},
|
||||
}
|
||||
|
||||
// Set the login URL generator
|
||||
p.storage.SetLoginURL(func(authRequestID string) string {
|
||||
return fmt.Sprintf("/login?authRequestID=%s", authRequestID)
|
||||
})
|
||||
|
||||
// Create the provider with options
|
||||
var opts []op.Option
|
||||
if p.config.DevMode {
|
||||
opts = append(opts, op.WithAllowInsecure())
|
||||
}
|
||||
|
||||
provider, err := op.NewProvider(opConfig, p.storage, op.StaticIssuer(p.config.Issuer), opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create OIDC provider: %w", err)
|
||||
}
|
||||
p.provider = provider
|
||||
|
||||
// Set up login handler
|
||||
loginHandler, err := NewLoginHandler(p.storage, func(authRequestID string) string {
|
||||
return provider.AuthorizationEndpoint().Absolute("/authorize/callback") + "?id=" + authRequestID
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create login handler: %w", err)
|
||||
}
|
||||
|
||||
// Set up device handler
|
||||
deviceHandler, err := NewDeviceHandler(p.storage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create device handler: %w", err)
|
||||
}
|
||||
|
||||
// Mount routes
|
||||
router.Mount("/login", loginHandler.Router())
|
||||
router.Mount("/device", deviceHandler.Router())
|
||||
router.Get("/logged-out", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Write([]byte(`<!DOCTYPE html><html><head><title>Logged Out</title></head><body><h1>You have been logged out</h1><p>You can close this window.</p></body></html>`))
|
||||
})
|
||||
|
||||
// Mount the OIDC provider at root
|
||||
router.Mount("/", provider)
|
||||
|
||||
p.router = router
|
||||
|
||||
// Create HTTP server
|
||||
addr := fmt.Sprintf(":%d", p.config.Port)
|
||||
p.httpServer = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: router,
|
||||
}
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
log.Infof("Starting OIDC provider on %s (issuer: %s)", addr, p.config.Issuer)
|
||||
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Errorf("OIDC provider server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start cleanup goroutine
|
||||
go p.cleanupLoop(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the OIDC provider server
|
||||
func (p *Provider) Stop(ctx context.Context) error {
|
||||
if p.httpServer != nil {
|
||||
if err := p.httpServer.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("failed to shutdown OIDC server: %w", err)
|
||||
}
|
||||
}
|
||||
if p.store != nil {
|
||||
if err := p.store.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close OIDC store: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up expired tokens
|
||||
func (p *Provider) cleanupLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := p.store.CleanupExpired(ctx); err != nil {
|
||||
log.Warnf("OIDC cleanup error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store returns the underlying store for user/client management
|
||||
func (p *Provider) Store() *Store {
|
||||
return p.store
|
||||
}
|
||||
|
||||
// GetIssuer returns the issuer URL
|
||||
func (p *Provider) GetIssuer() string {
|
||||
return p.config.Issuer
|
||||
}
|
||||
|
||||
// GetDiscoveryEndpoint returns the OpenID Connect discovery endpoint
|
||||
func (p *Provider) GetDiscoveryEndpoint() string {
|
||||
return p.config.Issuer + "/.well-known/openid-configuration"
|
||||
}
|
||||
|
||||
// GetTokenEndpoint returns the token endpoint
|
||||
func (p *Provider) GetTokenEndpoint() string {
|
||||
return p.config.Issuer + "/oauth/token"
|
||||
}
|
||||
|
||||
// GetAuthorizationEndpoint returns the authorization endpoint
|
||||
func (p *Provider) GetAuthorizationEndpoint() string {
|
||||
return p.config.Issuer + "/authorize"
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationEndpoint returns the device authorization endpoint
|
||||
func (p *Provider) GetDeviceAuthorizationEndpoint() string {
|
||||
return p.config.Issuer + "/device_authorization"
|
||||
}
|
||||
|
||||
// GetJWKSEndpoint returns the JWKS endpoint
|
||||
func (p *Provider) GetJWKSEndpoint() string {
|
||||
return p.config.Issuer + "/keys"
|
||||
}
|
||||
|
||||
// GetUserInfoEndpoint returns the userinfo endpoint
|
||||
func (p *Provider) GetUserInfoEndpoint() string {
|
||||
return p.config.Issuer + "/userinfo"
|
||||
}
|
||||
|
||||
// EnsureDefaultClients ensures the default NetBird clients exist
|
||||
func (p *Provider) EnsureDefaultClients(ctx context.Context, dashboardRedirectURIs, cliRedirectURIs []string) error {
|
||||
// Check if CLI client exists
|
||||
_, err := p.store.GetClientByID(ctx, "netbird-client")
|
||||
if err != nil {
|
||||
// Create CLI client (native, public, supports PKCE and device flow)
|
||||
cliClient := CreateNativeClient("netbird-client", "NetBird CLI", cliRedirectURIs)
|
||||
if err := p.store.CreateClient(ctx, cliClient); err != nil {
|
||||
return fmt.Errorf("failed to create CLI client: %w", err)
|
||||
}
|
||||
log.Info("Created default NetBird CLI client")
|
||||
}
|
||||
|
||||
// Check if dashboard client exists
|
||||
_, err = p.store.GetClientByID(ctx, "netbird-dashboard")
|
||||
if err != nil {
|
||||
// Create dashboard client (SPA, public, supports PKCE)
|
||||
dashboardClient := CreateSPAClient("netbird-dashboard", "NetBird Dashboard", dashboardRedirectURIs)
|
||||
if err := p.store.CreateClient(ctx, dashboardClient); err != nil {
|
||||
return fmt.Errorf("failed to create dashboard client: %w", err)
|
||||
}
|
||||
log.Info("Created default NetBird Dashboard client")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user (convenience method)
|
||||
func (p *Provider) CreateUser(ctx context.Context, username, password, email, firstName, lastName string) (*User, error) {
|
||||
user := &User{
|
||||
Username: username,
|
||||
Password: password, // Will be hashed by store
|
||||
Email: email,
|
||||
EmailVerified: false,
|
||||
FirstName: firstName,
|
||||
LastName: lastName,
|
||||
}
|
||||
|
||||
if err := p.store.CreateUser(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
493
idp/oidcprovider/store.go
Normal file
493
idp/oidcprovider/store.go
Normal file
@@ -0,0 +1,493 @@
|
||||
package oidcprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// Store handles persistence for OIDC provider data
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewStore creates a new Store with SQLite backend
|
||||
func NewStore(ctx context.Context, dataDir string) (*Store, error) {
|
||||
dbPath := fmt.Sprintf("%s/oidc.db", dataDir)
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open OIDC database: %w", err)
|
||||
}
|
||||
|
||||
// Enable WAL mode for better concurrency
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to enable WAL mode: %v", err)
|
||||
}
|
||||
|
||||
// Auto-migrate tables
|
||||
if err := db.AutoMigrate(
|
||||
&User{},
|
||||
&Client{},
|
||||
&AuthRequest{},
|
||||
&AuthCode{},
|
||||
&AccessToken{},
|
||||
&RefreshToken{},
|
||||
&DeviceAuth{},
|
||||
&SigningKey{},
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate OIDC database: %w", err)
|
||||
}
|
||||
|
||||
store := &Store{db: db}
|
||||
|
||||
// Ensure we have a signing key
|
||||
if err := store.ensureSigningKey(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure signing key: %w", err)
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (s *Store) Close() error {
|
||||
sqlDB, err := s.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
|
||||
// ensureSigningKey creates a signing key if one doesn't exist
|
||||
func (s *Store) ensureSigningKey(ctx context.Context) error {
|
||||
var key SigningKey
|
||||
err := s.db.WithContext(ctx).Where("active = ?", true).First(&key).Error
|
||||
if err == nil {
|
||||
return nil // Key exists
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate new RSA key pair
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate RSA key: %w", err)
|
||||
}
|
||||
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyBytes,
|
||||
})
|
||||
|
||||
newKey := &SigningKey{
|
||||
ID: uuid.New().String(),
|
||||
Algorithm: "RS256",
|
||||
PrivateKey: privateKeyPEM,
|
||||
PublicKey: publicKeyPEM,
|
||||
CreatedAt: time.Now(),
|
||||
Active: true,
|
||||
}
|
||||
|
||||
return s.db.WithContext(ctx).Create(newKey).Error
|
||||
}
|
||||
|
||||
// GetSigningKey returns the active signing key
|
||||
func (s *Store) GetSigningKey(ctx context.Context) (*SigningKey, error) {
|
||||
var key SigningKey
|
||||
err := s.db.WithContext(ctx).Where("active = ?", true).First(&key).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
// User operations
|
||||
|
||||
// CreateUser creates a new user with bcrypt hashed password
|
||||
func (s *Store) CreateUser(ctx context.Context, user *User) error {
|
||||
if user.ID == "" {
|
||||
user.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
user.Password = string(hashedPassword)
|
||||
user.CreatedAt = time.Now()
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
return s.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by ID
|
||||
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||
var user User
|
||||
err := s.db.WithContext(ctx).Where("id = ?", id).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByUsername retrieves a user by username
|
||||
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
||||
var user User
|
||||
err := s.db.WithContext(ctx).Where("username = ?", username).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// ValidateUserPassword validates a user's password
|
||||
func (s *Store) ValidateUserPassword(ctx context.Context, username, password string) (*User, error) {
|
||||
user, err := s.GetUserByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
||||
return nil, errors.New("invalid password")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ListUsers returns all users
|
||||
func (s *Store) ListUsers(ctx context.Context) ([]*User, error) {
|
||||
var users []*User
|
||||
err := s.db.WithContext(ctx).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
// UpdateUser updates a user
|
||||
func (s *Store) UpdateUser(ctx context.Context, user *User) error {
|
||||
user.UpdatedAt = time.Now()
|
||||
return s.db.WithContext(ctx).Save(user).Error
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user
|
||||
func (s *Store) DeleteUser(ctx context.Context, id string) error {
|
||||
return s.db.WithContext(ctx).Delete(&User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// UpdateUserPassword updates a user's password
|
||||
func (s *Store) UpdateUserPassword(ctx context.Context, id, password string) error {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
return s.db.WithContext(ctx).Model(&User{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||||
"password": string(hashedPassword),
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// Client operations
|
||||
|
||||
// CreateClient creates a new OIDC client
|
||||
func (s *Store) CreateClient(ctx context.Context, client *Client) error {
|
||||
if client.ID == "" {
|
||||
client.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
// Hash secret if provided
|
||||
if client.Secret != "" {
|
||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(client.Secret), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash client secret: %w", err)
|
||||
}
|
||||
client.Secret = string(hashedSecret)
|
||||
}
|
||||
|
||||
client.CreatedAt = time.Now()
|
||||
client.UpdatedAt = time.Now()
|
||||
|
||||
return s.db.WithContext(ctx).Create(client).Error
|
||||
}
|
||||
|
||||
// GetClientByID retrieves a client by ID
|
||||
func (s *Store) GetClientByID(ctx context.Context, id string) (*Client, error) {
|
||||
var client Client
|
||||
err := s.db.WithContext(ctx).Where("id = ?", id).First(&client).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &client, nil
|
||||
}
|
||||
|
||||
// ValidateClientSecret validates a client's secret
|
||||
func (s *Store) ValidateClientSecret(ctx context.Context, clientID, secret string) (*Client, error) {
|
||||
client, err := s.GetClientByID(ctx, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Public clients have no secret
|
||||
if client.Secret == "" && secret == "" {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(secret)); err != nil {
|
||||
return nil, errors.New("invalid client secret")
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// ListClients returns all clients
|
||||
func (s *Store) ListClients(ctx context.Context) ([]*Client, error) {
|
||||
var clients []*Client
|
||||
err := s.db.WithContext(ctx).Find(&clients).Error
|
||||
return clients, err
|
||||
}
|
||||
|
||||
// DeleteClient deletes a client
|
||||
func (s *Store) DeleteClient(ctx context.Context, id string) error {
|
||||
return s.db.WithContext(ctx).Delete(&Client{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// AuthRequest operations
|
||||
|
||||
// SaveAuthRequest saves an authorization request
|
||||
func (s *Store) SaveAuthRequest(ctx context.Context, req *AuthRequest) error {
|
||||
if req.ID == "" {
|
||||
req.ID = uuid.New().String()
|
||||
}
|
||||
req.CreatedAt = time.Now()
|
||||
return s.db.WithContext(ctx).Create(req).Error
|
||||
}
|
||||
|
||||
// GetAuthRequestByID retrieves an auth request by ID
|
||||
func (s *Store) GetAuthRequestByID(ctx context.Context, id string) (*AuthRequest, error) {
|
||||
var req AuthRequest
|
||||
err := s.db.WithContext(ctx).Where("id = ?", id).First(&req).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
// UpdateAuthRequest updates an auth request
|
||||
func (s *Store) UpdateAuthRequest(ctx context.Context, req *AuthRequest) error {
|
||||
return s.db.WithContext(ctx).Save(req).Error
|
||||
}
|
||||
|
||||
// DeleteAuthRequest deletes an auth request
|
||||
func (s *Store) DeleteAuthRequest(ctx context.Context, id string) error {
|
||||
return s.db.WithContext(ctx).Delete(&AuthRequest{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// AuthCode operations
|
||||
|
||||
// SaveAuthCode saves an authorization code
|
||||
func (s *Store) SaveAuthCode(ctx context.Context, code *AuthCode) error {
|
||||
code.CreatedAt = time.Now()
|
||||
if code.ExpiresAt.IsZero() {
|
||||
code.ExpiresAt = time.Now().Add(10 * time.Minute) // 10 minute expiry
|
||||
}
|
||||
return s.db.WithContext(ctx).Create(code).Error
|
||||
}
|
||||
|
||||
// GetAuthCodeByCode retrieves an auth code
|
||||
func (s *Store) GetAuthCodeByCode(ctx context.Context, code string) (*AuthCode, error) {
|
||||
var authCode AuthCode
|
||||
err := s.db.WithContext(ctx).Where("code = ?", code).First(&authCode).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &authCode, nil
|
||||
}
|
||||
|
||||
// DeleteAuthCode deletes an auth code
|
||||
func (s *Store) DeleteAuthCode(ctx context.Context, code string) error {
|
||||
return s.db.WithContext(ctx).Delete(&AuthCode{}, "code = ?", code).Error
|
||||
}
|
||||
|
||||
// Token operations
|
||||
|
||||
// SaveAccessToken saves an access token
|
||||
func (s *Store) SaveAccessToken(ctx context.Context, token *AccessToken) error {
|
||||
if token.ID == "" {
|
||||
token.ID = uuid.New().String()
|
||||
}
|
||||
token.CreatedAt = time.Now()
|
||||
return s.db.WithContext(ctx).Create(token).Error
|
||||
}
|
||||
|
||||
// GetAccessTokenByID retrieves an access token
|
||||
func (s *Store) GetAccessTokenByID(ctx context.Context, id string) (*AccessToken, error) {
|
||||
var token AccessToken
|
||||
err := s.db.WithContext(ctx).Where("id = ?", id).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// DeleteAccessToken deletes an access token
|
||||
func (s *Store) DeleteAccessToken(ctx context.Context, id string) error {
|
||||
return s.db.WithContext(ctx).Delete(&AccessToken{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// RefreshToken operations
|
||||
|
||||
// SaveRefreshToken saves a refresh token
|
||||
func (s *Store) SaveRefreshToken(ctx context.Context, token *RefreshToken) error {
|
||||
if token.ID == "" {
|
||||
token.ID = uuid.New().String()
|
||||
}
|
||||
if token.Token == "" {
|
||||
token.Token = uuid.New().String()
|
||||
}
|
||||
token.CreatedAt = time.Now()
|
||||
return s.db.WithContext(ctx).Create(token).Error
|
||||
}
|
||||
|
||||
// GetRefreshToken retrieves a refresh token by token value
|
||||
func (s *Store) GetRefreshToken(ctx context.Context, token string) (*RefreshToken, error) {
|
||||
var rt RefreshToken
|
||||
err := s.db.WithContext(ctx).Where("token = ?", token).First(&rt).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rt, nil
|
||||
}
|
||||
|
||||
// DeleteRefreshToken deletes a refresh token
|
||||
func (s *Store) DeleteRefreshToken(ctx context.Context, id string) error {
|
||||
return s.db.WithContext(ctx).Delete(&RefreshToken{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// DeleteRefreshTokenByToken deletes a refresh token by token value
|
||||
func (s *Store) DeleteRefreshTokenByToken(ctx context.Context, token string) error {
|
||||
return s.db.WithContext(ctx).Delete(&RefreshToken{}, "token = ?", token).Error
|
||||
}
|
||||
|
||||
// DeviceAuth operations
|
||||
|
||||
// SaveDeviceAuth saves a device authorization
|
||||
func (s *Store) SaveDeviceAuth(ctx context.Context, auth *DeviceAuth) error {
|
||||
auth.CreatedAt = time.Now()
|
||||
return s.db.WithContext(ctx).Create(auth).Error
|
||||
}
|
||||
|
||||
// GetDeviceAuthByDeviceCode retrieves device auth by device code
|
||||
func (s *Store) GetDeviceAuthByDeviceCode(ctx context.Context, deviceCode string) (*DeviceAuth, error) {
|
||||
var auth DeviceAuth
|
||||
err := s.db.WithContext(ctx).Where("device_code = ?", deviceCode).First(&auth).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &auth, nil
|
||||
}
|
||||
|
||||
// GetDeviceAuthByUserCode retrieves device auth by user code
|
||||
func (s *Store) GetDeviceAuthByUserCode(ctx context.Context, userCode string) (*DeviceAuth, error) {
|
||||
var auth DeviceAuth
|
||||
err := s.db.WithContext(ctx).Where("user_code = ?", userCode).First(&auth).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &auth, nil
|
||||
}
|
||||
|
||||
// UpdateDeviceAuth updates a device authorization
|
||||
func (s *Store) UpdateDeviceAuth(ctx context.Context, auth *DeviceAuth) error {
|
||||
return s.db.WithContext(ctx).Save(auth).Error
|
||||
}
|
||||
|
||||
// DeleteDeviceAuth deletes a device authorization
|
||||
func (s *Store) DeleteDeviceAuth(ctx context.Context, deviceCode string) error {
|
||||
return s.db.WithContext(ctx).Delete(&DeviceAuth{}, "device_code = ?", deviceCode).Error
|
||||
}
|
||||
|
||||
// Cleanup operations
|
||||
|
||||
// CleanupExpired removes expired tokens and auth requests
|
||||
func (s *Store) CleanupExpired(ctx context.Context) error {
|
||||
now := time.Now()
|
||||
|
||||
// Delete expired auth codes
|
||||
if err := s.db.WithContext(ctx).Delete(&AuthCode{}, "expires_at < ?", now).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete expired access tokens
|
||||
if err := s.db.WithContext(ctx).Delete(&AccessToken{}, "expiration < ?", now).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete expired refresh tokens
|
||||
if err := s.db.WithContext(ctx).Delete(&RefreshToken{}, "expiration < ?", now).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete expired device authorizations
|
||||
if err := s.db.WithContext(ctx).Delete(&DeviceAuth{}, "expiration < ?", now).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete old auth requests (older than 1 hour)
|
||||
oneHourAgo := now.Add(-1 * time.Hour)
|
||||
if err := s.db.WithContext(ctx).Delete(&AuthRequest{}, "created_at < ?", oneHourAgo).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions for JSON serialization
|
||||
|
||||
// ParseJSONArray parses a JSON array string into a slice
|
||||
func ParseJSONArray(jsonStr string) []string {
|
||||
if jsonStr == "" {
|
||||
return nil
|
||||
}
|
||||
var result []string
|
||||
if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToJSONArray converts a slice to a JSON array string
|
||||
func ToJSONArray(arr []string) string {
|
||||
if len(arr) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
data, err := json.Marshal(arr)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
261
idp/oidcprovider/templates/device.html
Normal file
261
idp/oidcprovider/templates/device.html
Normal file
@@ -0,0 +1,261 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Device Authorization - NetBird</title>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 20px;
|
||||
}
|
||||
.container {
|
||||
background: white;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
|
||||
width: 100%;
|
||||
max-width: 450px;
|
||||
}
|
||||
.logo {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.logo h1 {
|
||||
font-size: 28px;
|
||||
color: #333;
|
||||
font-weight: 600;
|
||||
}
|
||||
.logo p {
|
||||
color: #666;
|
||||
margin-top: 8px;
|
||||
}
|
||||
.form-group {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
label {
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
color: #333;
|
||||
font-weight: 500;
|
||||
}
|
||||
input[type="text"],
|
||||
input[type="password"] {
|
||||
width: 100%;
|
||||
padding: 14px 16px;
|
||||
border: 2px solid #e1e5eb;
|
||||
border-radius: 8px;
|
||||
font-size: 16px;
|
||||
transition: border-color 0.2s, box-shadow 0.2s;
|
||||
}
|
||||
input.code-input {
|
||||
text-align: center;
|
||||
font-size: 24px;
|
||||
letter-spacing: 4px;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
input[type="text"]:focus,
|
||||
input[type="password"]:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.2);
|
||||
}
|
||||
button {
|
||||
width: 100%;
|
||||
padding: 14px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: transform 0.2s, box-shadow 0.2s;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
button:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 10px 20px rgba(102, 126, 234, 0.3);
|
||||
}
|
||||
button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
button.secondary {
|
||||
background: #e1e5eb;
|
||||
color: #333;
|
||||
}
|
||||
button.secondary:hover {
|
||||
background: #d1d5db;
|
||||
box-shadow: none;
|
||||
}
|
||||
button.deny {
|
||||
background: #dc2626;
|
||||
}
|
||||
button.deny:hover {
|
||||
background: #b91c1c;
|
||||
}
|
||||
.error {
|
||||
background: #fee;
|
||||
color: #c00;
|
||||
padding: 12px 16px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 20px;
|
||||
font-size: 14px;
|
||||
border: 1px solid #fcc;
|
||||
}
|
||||
.success {
|
||||
background: #d4edda;
|
||||
color: #155724;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
font-size: 16px;
|
||||
border: 1px solid #c3e6cb;
|
||||
}
|
||||
.info {
|
||||
background: #e8f4fd;
|
||||
color: #0c5460;
|
||||
padding: 16px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 20px;
|
||||
font-size: 14px;
|
||||
border: 1px solid #bee5eb;
|
||||
}
|
||||
.scopes {
|
||||
background: #f8f9fa;
|
||||
padding: 16px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.scopes h3 {
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.scopes ul {
|
||||
list-style: none;
|
||||
padding: 0;
|
||||
}
|
||||
.scopes li {
|
||||
padding: 8px 0;
|
||||
border-bottom: 1px solid #e1e5eb;
|
||||
color: #333;
|
||||
}
|
||||
.scopes li:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
.button-group {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
}
|
||||
.button-group button {
|
||||
flex: 1;
|
||||
}
|
||||
.footer {
|
||||
text-align: center;
|
||||
margin-top: 24px;
|
||||
color: #888;
|
||||
font-size: 13px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="logo">
|
||||
<h1>NetBird</h1>
|
||||
<p>Device Authorization</p>
|
||||
</div>
|
||||
|
||||
{{if .Error}}
|
||||
<div class="error">{{.Error}}</div>
|
||||
{{end}}
|
||||
|
||||
{{if eq .Step "code"}}
|
||||
<!-- Step 1: Enter user code -->
|
||||
<div class="info">
|
||||
Enter the code shown on your device to authorize it.
|
||||
</div>
|
||||
<form method="GET" action="/device">
|
||||
<div class="form-group">
|
||||
<label for="user_code">Device Code</label>
|
||||
<input type="text" id="user_code" name="user_code" class="code-input"
|
||||
placeholder="XXXX-XXXX" required autofocus
|
||||
pattern="[A-Za-z]{4}-?[A-Za-z]{4}">
|
||||
</div>
|
||||
<button type="submit">Continue</button>
|
||||
</form>
|
||||
{{end}}
|
||||
|
||||
{{if eq .Step "login"}}
|
||||
<!-- Step 2: Login -->
|
||||
<div class="info">
|
||||
Sign in to authorize the device.
|
||||
</div>
|
||||
<form method="POST" action="/device/login">
|
||||
<input type="hidden" name="user_code" value="{{.UserCode}}">
|
||||
<div class="form-group">
|
||||
<label for="username">Username</label>
|
||||
<input type="text" id="username" name="username" required autofocus>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="password">Password</label>
|
||||
<input type="password" id="password" name="password" required>
|
||||
</div>
|
||||
<button type="submit">Sign In</button>
|
||||
</form>
|
||||
{{end}}
|
||||
|
||||
{{if eq .Step "confirm"}}
|
||||
<!-- Step 3: Confirm authorization -->
|
||||
<div class="info">
|
||||
<strong>{{.ClientID}}</strong> is requesting access to your account.
|
||||
</div>
|
||||
|
||||
{{if .Scopes}}
|
||||
<div class="scopes">
|
||||
<h3>This application will have access to:</h3>
|
||||
<ul>
|
||||
{{range .Scopes}}
|
||||
<li>{{.}}</li>
|
||||
{{end}}
|
||||
</ul>
|
||||
</div>
|
||||
{{end}}
|
||||
|
||||
<form method="POST" action="/device/confirm">
|
||||
<div class="button-group">
|
||||
<button type="submit" name="action" value="allow">Allow</button>
|
||||
<button type="submit" name="action" value="deny" class="deny">Deny</button>
|
||||
</div>
|
||||
</form>
|
||||
{{end}}
|
||||
|
||||
{{if eq .Step "result"}}
|
||||
<!-- Result -->
|
||||
{{if .Success}}
|
||||
<div class="success">
|
||||
{{.Message}}
|
||||
</div>
|
||||
{{else}}
|
||||
<div class="info">
|
||||
{{.Message}}
|
||||
</div>
|
||||
{{end}}
|
||||
{{end}}
|
||||
|
||||
<div class="footer">
|
||||
Powered by NetBird Identity Provider
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
129
idp/oidcprovider/templates/login.html
Normal file
129
idp/oidcprovider/templates/login.html
Normal file
@@ -0,0 +1,129 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Login - NetBird</title>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 20px;
|
||||
}
|
||||
.login-container {
|
||||
background: white;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
|
||||
width: 100%;
|
||||
max-width: 400px;
|
||||
}
|
||||
.logo {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.logo h1 {
|
||||
font-size: 28px;
|
||||
color: #333;
|
||||
font-weight: 600;
|
||||
}
|
||||
.logo p {
|
||||
color: #666;
|
||||
margin-top: 8px;
|
||||
}
|
||||
.form-group {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
label {
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
color: #333;
|
||||
font-weight: 500;
|
||||
}
|
||||
input[type="text"],
|
||||
input[type="password"] {
|
||||
width: 100%;
|
||||
padding: 14px 16px;
|
||||
border: 2px solid #e1e5eb;
|
||||
border-radius: 8px;
|
||||
font-size: 16px;
|
||||
transition: border-color 0.2s, box-shadow 0.2s;
|
||||
}
|
||||
input[type="text"]:focus,
|
||||
input[type="password"]:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.2);
|
||||
}
|
||||
button {
|
||||
width: 100%;
|
||||
padding: 14px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: transform 0.2s, box-shadow 0.2s;
|
||||
}
|
||||
button:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 10px 20px rgba(102, 126, 234, 0.3);
|
||||
}
|
||||
button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
.error {
|
||||
background: #fee;
|
||||
color: #c00;
|
||||
padding: 12px 16px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 20px;
|
||||
font-size: 14px;
|
||||
border: 1px solid #fcc;
|
||||
}
|
||||
.footer {
|
||||
text-align: center;
|
||||
margin-top: 24px;
|
||||
color: #888;
|
||||
font-size: 13px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="login-container">
|
||||
<div class="logo">
|
||||
<h1>NetBird</h1>
|
||||
<p>Sign in to your account</p>
|
||||
</div>
|
||||
{{if .Error}}
|
||||
<div class="error">{{.Error}}</div>
|
||||
{{end}}
|
||||
<form method="POST" action="/login">
|
||||
<input type="hidden" name="authRequestID" value="{{.AuthRequestID}}">
|
||||
<div class="form-group">
|
||||
<label for="username">Username</label>
|
||||
<input type="text" id="username" name="username" required autofocus>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="password">Password</label>
|
||||
<input type="password" id="password" name="password" required>
|
||||
</div>
|
||||
<button type="submit">Sign In</button>
|
||||
</form>
|
||||
<div class="footer">
|
||||
Powered by NetBird Identity Provider
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -42,6 +43,7 @@ type Controller struct {
|
||||
accountManagerMetrics *telemetry.AccountManagerMetrics
|
||||
peersUpdateManager network_map.PeersUpdateManager
|
||||
settingsManager settings.Manager
|
||||
EphemeralPeersManager ephemeral.Manager
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
sendAccountUpdateLocks sync.Map
|
||||
@@ -70,7 +72,7 @@ type bufferUpdate struct {
|
||||
|
||||
var _ network_map.Controller = (*Controller)(nil)
|
||||
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller {
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
|
||||
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||
@@ -99,7 +101,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
dnsDomain: dnsDomain,
|
||||
config: config,
|
||||
|
||||
proxyController: proxyController,
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
|
||||
holder: types.NewHolder(),
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
@@ -107,6 +110,31 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
||||
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err)
|
||||
}
|
||||
|
||||
c.EphemeralPeersManager.OnPeerConnected(ctx, peer)
|
||||
|
||||
return c.peersUpdateManager.CreateChannel(ctx, peerID), nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) {
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err)
|
||||
return
|
||||
}
|
||||
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
func (c *Controller) CountStreams() int {
|
||||
return c.peersUpdateManager.CountStreams()
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
var (
|
||||
@@ -366,38 +394,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peers, err := c.repo.GetAccountPeers(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
|
||||
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: network.CurrentSerial(),
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
DNSConfig: &proto.DNSConfig{
|
||||
ForwarderPort: dnsFwdPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
@@ -698,35 +694,83 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
|
||||
c.UpdatePeerInNetworkMapCache(accountId, peer)
|
||||
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get peers by ids: %w", err)
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
c.UpdatePeerInNetworkMapCache(accountID, peer)
|
||||
}
|
||||
|
||||
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
for _, peerID := range peerIDs {
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peers, err := c.repo.GetAccountPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
|
||||
for _, peerID := range peerIDs {
|
||||
c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: network.CurrentSerial(),
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
DNSConfig: &proto.DNSConfig{
|
||||
ForwarderPort: dnsFwdPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||
continue
|
||||
}
|
||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -778,10 +822,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
return networkMap, nil
|
||||
}
|
||||
|
||||
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
|
||||
func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
|
||||
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
}
|
||||
|
||||
func (c *Controller) IsConnected(peerID string) bool {
|
||||
return c.peersUpdateManager.HasChannel(peerID)
|
||||
}
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
@@ -114,131 +107,3 @@ func TestComputeForwarderPort(t *testing.T) {
|
||||
t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferUpdateAccountPeers(t *testing.T) {
|
||||
const (
|
||||
peersCount = 1000
|
||||
updateAccountInterval = 50 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
|
||||
uapLastRun, dpLastRun atomic.Int64
|
||||
|
||||
totalNewRuns, totalOldRuns int
|
||||
)
|
||||
|
||||
uap := func(ctx context.Context, accountID string) {
|
||||
updatePeersDeleted.Store(deletedPeers.Load())
|
||||
updatePeersRuns.Add(1)
|
||||
uapLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Run("new approach", func(t *testing.T) {
|
||||
updatePeersRuns.Store(0)
|
||||
updatePeersDeleted.Store(0)
|
||||
deletedPeers.Store(0)
|
||||
|
||||
var mustore sync.Map
|
||||
bufupd := func(ctx context.Context, accountID string) {
|
||||
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := mu.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
uap(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
b.next = time.AfterFunc(updateAccountInterval, func() {
|
||||
uap(ctx, accountID)
|
||||
})
|
||||
}()
|
||||
}
|
||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||
deletedPeers.Add(1)
|
||||
dpLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
bufupd(ctx, accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
am := mock_server.MockAccountManager{
|
||||
UpdateAccountPeersFunc: uap,
|
||||
BufferUpdateAccountPeersFunc: bufupd,
|
||||
DeletePeerFunc: dp,
|
||||
}
|
||||
empty := ""
|
||||
for range peersCount {
|
||||
//nolint
|
||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||
|
||||
totalNewRuns = int(updatePeersRuns.Load())
|
||||
})
|
||||
|
||||
t.Run("old approach", func(t *testing.T) {
|
||||
updatePeersRuns.Store(0)
|
||||
updatePeersDeleted.Store(0)
|
||||
deletedPeers.Store(0)
|
||||
|
||||
var mustore sync.Map
|
||||
bufupd := func(ctx context.Context, accountID string) {
|
||||
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
|
||||
b := mu.(*sync.Mutex)
|
||||
|
||||
if !b.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(updateAccountInterval)
|
||||
b.Unlock()
|
||||
uap(ctx, accountID)
|
||||
}()
|
||||
}
|
||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||
deletedPeers.Add(1)
|
||||
dpLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
bufupd(ctx, accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
am := mock_server.MockAccountManager{
|
||||
UpdateAccountPeersFunc: uap,
|
||||
BufferUpdateAccountPeersFunc: bufupd,
|
||||
DeletePeerFunc: dp,
|
||||
}
|
||||
empty := ""
|
||||
for range peersCount {
|
||||
//nolint
|
||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||
|
||||
totalOldRuns = int(updatePeersRuns.Load())
|
||||
})
|
||||
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ type Repository interface {
|
||||
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
|
||||
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
||||
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
@@ -37,3 +39,11 @@ func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*
|
||||
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
|
||||
return r.store.GetAccountByPeerID(ctx, peerID)
|
||||
}
|
||||
|
||||
func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
|
||||
return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs)
|
||||
}
|
||||
|
||||
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
|
||||
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
@@ -28,12 +28,12 @@ type Controller interface {
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
CountStreams() int
|
||||
|
||||
DeletePeer(ctx context.Context, accountId string, peerId string) error
|
||||
|
||||
OnPeerUpdated(accountId string, peer *nbpeer.Peer)
|
||||
OnPeerAdded(ctx context.Context, accountID string, peerID string) error
|
||||
OnPeerDeleted(ctx context.Context, accountID string, peerID string) error
|
||||
DisconnectPeers(ctx context.Context, peerIDs []string)
|
||||
IsConnected(peerID string) bool
|
||||
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
|
||||
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
|
||||
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
|
||||
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
|
||||
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user