mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-28 11:39:55 +00:00
Compare commits
1 Commits
add-steamo
...
feature/an
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a94dfd732f |
158
.github/workflows/release.yml
vendored
158
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.4"
|
||||
SIGN_PIPE_VER: "v0.1.2"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -115,12 +115,6 @@ jobs:
|
||||
|
||||
release:
|
||||
runs-on: ubuntu-latest-m
|
||||
outputs:
|
||||
release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
|
||||
linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
|
||||
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
|
||||
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
|
||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
@@ -219,13 +213,10 @@ jobs:
|
||||
if: always()
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: Tag and push images (amd64 only)
|
||||
id: tag_and_push_images
|
||||
if: |
|
||||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
|
||||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
resolve_tags() {
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||
echo "pr-${{ github.event.pull_request.number }}"
|
||||
@@ -234,17 +225,6 @@ jobs:
|
||||
fi
|
||||
}
|
||||
|
||||
ghcr_package_url() {
|
||||
local image="$1" package encoded_package
|
||||
package="${image#ghcr.io/}"
|
||||
package="${package#*/}"
|
||||
package="${package%%:*}"
|
||||
encoded_package="${package//\//%2F}"
|
||||
echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
|
||||
}
|
||||
|
||||
image_refs=()
|
||||
|
||||
tag_and_push() {
|
||||
local src="$1" img_name tag dst
|
||||
img_name="${src%%:*}"
|
||||
@@ -253,56 +233,35 @@ jobs:
|
||||
echo "Tagging ${src} -> ${dst}"
|
||||
docker tag "$src" "$dst"
|
||||
docker push "$dst"
|
||||
image_refs+=("$dst")
|
||||
done
|
||||
}
|
||||
|
||||
cat > /tmp/goreleaser-artifacts.json <<'JSON'
|
||||
${{ steps.goreleaser.outputs.artifacts }}
|
||||
JSON
|
||||
export -f tag_and_push resolve_tags
|
||||
|
||||
mapfile -t src_images < <(
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
|
||||
)
|
||||
|
||||
for src in "${src_images[@]}"; do
|
||||
tag_and_push "$src"
|
||||
done
|
||||
|
||||
{
|
||||
echo "images_markdown<<EOF"
|
||||
if [[ ${#image_refs[@]} -eq 0 ]]; then
|
||||
echo "_No GHCR images were pushed._"
|
||||
else
|
||||
printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
|
||||
printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
|
||||
done
|
||||
fi
|
||||
echo "EOF"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||
grep '^ghcr.io/' | while read -r SRC; do
|
||||
tag_and_push "$SRC"
|
||||
done
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 7
|
||||
- name: upload linux packages
|
||||
id: upload_linux_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 7
|
||||
- name: upload windows packages
|
||||
id: upload_windows_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 7
|
||||
- name: upload macos packages
|
||||
id: upload_macos_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: macos-packages
|
||||
@@ -311,8 +270,6 @@ jobs:
|
||||
|
||||
release_ui:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
@@ -403,7 +360,6 @@ jobs:
|
||||
if: always()
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-ui
|
||||
@@ -412,8 +368,6 @@ jobs:
|
||||
|
||||
release_ui_darwin:
|
||||
runs-on: macos-latest
|
||||
outputs:
|
||||
release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
@@ -448,110 +402,12 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui_darwin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
retention-days: 3
|
||||
|
||||
comment_release_artifacts:
|
||||
name: Comment release artifacts
|
||||
runs-on: ubuntu-latest
|
||||
needs: [release, release_ui, release_ui_darwin]
|
||||
if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Create or update PR comment
|
||||
uses: actions/github-script@v7
|
||||
env:
|
||||
RELEASE_RESULT: ${{ needs.release.result }}
|
||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
||||
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
|
||||
RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
|
||||
LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
|
||||
WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
|
||||
MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
|
||||
RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
|
||||
RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
|
||||
GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const marker = '<!-- netbird-release-artifacts -->';
|
||||
const { owner, repo } = context.repo;
|
||||
const issue_number = context.payload.pull_request.number;
|
||||
const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
|
||||
const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
|
||||
|
||||
const artifactCell = (url, result) => {
|
||||
if (url) return `[Download](${url})`;
|
||||
return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
|
||||
};
|
||||
|
||||
const artifacts = [
|
||||
['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
||||
['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
||||
['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
||||
['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
||||
['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
|
||||
['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
|
||||
];
|
||||
|
||||
const artifactRows = artifacts
|
||||
.map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
|
||||
.join('\n');
|
||||
|
||||
const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
|
||||
|
||||
const body = [
|
||||
marker,
|
||||
'## Release artifacts',
|
||||
'',
|
||||
`Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
|
||||
'',
|
||||
'| Artifact | Link |',
|
||||
'| --- | --- |',
|
||||
artifactRows,
|
||||
'',
|
||||
'### GHCR images (amd64)',
|
||||
ghcrImages,
|
||||
'',
|
||||
'_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
|
||||
].join('\n');
|
||||
|
||||
const comments = await github.paginate(github.rest.issues.listComments, {
|
||||
owner,
|
||||
repo,
|
||||
issue_number,
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
const previous = comments.find(comment =>
|
||||
comment.user?.type === 'Bot' && comment.body?.includes(marker)
|
||||
);
|
||||
|
||||
if (previous) {
|
||||
await github.rest.issues.updateComment({
|
||||
owner,
|
||||
repo,
|
||||
comment_id: previous.id,
|
||||
body,
|
||||
});
|
||||
core.info(`Updated release artifacts comment ${previous.id}`);
|
||||
} else {
|
||||
const { data } = await github.rest.issues.createComment({
|
||||
owner,
|
||||
repo,
|
||||
issue_number,
|
||||
body,
|
||||
});
|
||||
core.info(`Created release artifacts comment ${data.id}`);
|
||||
}
|
||||
|
||||
trigger_signer:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [release, release_ui, release_ui_darwin]
|
||||
|
||||
434
client/android/ssh_client.go
Normal file
434
client/android/ssh_client.go
Normal file
@@ -0,0 +1,434 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
)
|
||||
|
||||
const (
|
||||
sshDialTimeout = 30 * time.Second
|
||||
sshDetectionTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// SSHTerminalListener receives SSH session events. It is implemented in Java.
|
||||
//
|
||||
// All callbacks are invoked from goroutines and may run concurrently with each
|
||||
// other; the implementation must be safe to call from any thread.
|
||||
type SSHTerminalListener interface {
|
||||
OnConnected()
|
||||
OnData(data []byte)
|
||||
OnClose(reason string)
|
||||
OnError(message string)
|
||||
}
|
||||
|
||||
// SSHClient is a NetBird-aware SSH client exposed to Java via gomobile.
|
||||
//
|
||||
// It dials through the running NetBird tunnel and runs a standard SSH session
|
||||
// on top with PTY enabled. Host-key verification uses the NetBird-provided
|
||||
// peer SSH host keys, identical to the desktop client.
|
||||
type SSHClient struct {
|
||||
nb *Client
|
||||
mu sync.Mutex
|
||||
listener SSHTerminalListener
|
||||
urlOpener URLOpener
|
||||
|
||||
sshClient *gossh.Client
|
||||
session *gossh.Session
|
||||
stdin io.WriteCloser
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewSSHClient creates a new SSH client bound to the running NetBird Client.
|
||||
func NewSSHClient(c *Client) *SSHClient {
|
||||
return &SSHClient{nb: c}
|
||||
}
|
||||
|
||||
// SetListener registers the Java listener. Must be called before Connect to
|
||||
// receive any events.
|
||||
func (s *SSHClient) SetListener(l SSHTerminalListener) {
|
||||
s.mu.Lock()
|
||||
s.listener = l
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetURLOpener registers the Java URL opener used to display the device-code
|
||||
// authorization page in a Custom Tabs window when the target peer requires
|
||||
// JWT authentication. Must be set before Connect to be effective.
|
||||
func (s *SSHClient) SetURLOpener(opener URLOpener) {
|
||||
s.mu.Lock()
|
||||
s.urlOpener = opener
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Connect dials the SSH server through the NetBird tunnel and performs the
|
||||
// SSH handshake. It auto-detects the server type via SSH banner inspection
|
||||
// and selects the appropriate authentication path:
|
||||
//
|
||||
// - NetBird-SSH server requiring JWT: launches the OAuth 2.0 device-code
|
||||
// flow, opens the verification URL through the registered URLOpener, and
|
||||
// uses the resulting token as the SSH password. Host-key verification
|
||||
// uses the NetBird peer registry.
|
||||
// - NetBird-SSH server without JWT: authenticates with the NetBird SSH
|
||||
// private key. Host-key verification uses the NetBird peer registry.
|
||||
// - Regular SSH server (e.g. OpenSSH): authenticates with the NetBird key
|
||||
// first (so a user-installed NetBird public key works), then falls back
|
||||
// to the supplied password if non-empty. Host-key verification is
|
||||
// disabled (TOFU pending).
|
||||
//
|
||||
// The password parameter is only consulted for regular SSH servers.
|
||||
func (s *SSHClient) Connect(host string, port int, user, password string) error {
|
||||
cfg, _, cc := s.nb.stateSnapshot()
|
||||
if cc == nil {
|
||||
return errors.New("netbird client not running")
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New("netbird config not loaded")
|
||||
}
|
||||
engine := cc.Engine()
|
||||
if engine == nil {
|
||||
return errors.New("netbird engine not available")
|
||||
}
|
||||
|
||||
serverType := detectServerType(host, port)
|
||||
log.Infof("SSH server type for %s:%d: %s", host, port, serverType)
|
||||
|
||||
authMethods, hostKeyCallback, err := s.buildAuth(cfg, engine, serverType, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientConfig := &gossh.ClientConfig{
|
||||
User: user,
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
Timeout: sshDialTimeout,
|
||||
}
|
||||
return s.dialAndHandshake(host, port, clientConfig)
|
||||
}
|
||||
|
||||
// StartSession requests a PTY and starts an interactive shell. Output from
|
||||
// the session is forwarded to the listener via OnData.
|
||||
func (s *SSHClient) StartSession(cols, rows int) error {
|
||||
log.Debugf("SSH: starting session %dx%d", cols, rows)
|
||||
s.mu.Lock()
|
||||
sshClient := s.sshClient
|
||||
s.mu.Unlock()
|
||||
|
||||
if sshClient == nil {
|
||||
return errors.New("ssh client not connected")
|
||||
}
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
|
||||
modes := gossh.TerminalModes{
|
||||
gossh.ECHO: 1,
|
||||
gossh.TTY_OP_ISPEED: 14400,
|
||||
gossh.TTY_OP_OSPEED: 14400,
|
||||
gossh.VINTR: 3,
|
||||
gossh.VQUIT: 28,
|
||||
gossh.VERASE: 127,
|
||||
}
|
||||
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
||||
closeQuiet(session, "session after pty error")
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stdin error")
|
||||
return fmt.Errorf("stdin pipe: %w", err)
|
||||
}
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stdout error")
|
||||
return fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := session.StderrPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stderr error")
|
||||
return fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
closeQuiet(session, "session after shell error")
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.session = session
|
||||
s.stdin = stdin
|
||||
s.mu.Unlock()
|
||||
|
||||
go s.readLoop(stdout, "stdout")
|
||||
go s.readLoop(stderr, "stderr")
|
||||
log.Debug("SSH: session started, shell running")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write sends data to the SSH session stdin.
|
||||
func (s *SSHClient) Write(data []byte) error {
|
||||
s.mu.Lock()
|
||||
stdin := s.stdin
|
||||
s.mu.Unlock()
|
||||
if stdin == nil {
|
||||
return errors.New("ssh session not started")
|
||||
}
|
||||
if _, err := stdin.Write(data); err != nil {
|
||||
return fmt.Errorf("write stdin: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resize updates the PTY window size.
|
||||
func (s *SSHClient) Resize(cols, rows int) error {
|
||||
s.mu.Lock()
|
||||
session := s.session
|
||||
s.mu.Unlock()
|
||||
if session == nil {
|
||||
return errors.New("ssh session not started")
|
||||
}
|
||||
return session.WindowChange(rows, cols)
|
||||
}
|
||||
|
||||
// Close terminates the SSH session and underlying connection. Safe to call
|
||||
// multiple times.
|
||||
func (s *SSHClient) Close() error {
|
||||
s.mu.Lock()
|
||||
sshClient := s.sshClient
|
||||
session := s.session
|
||||
stdin := s.stdin
|
||||
s.sshClient = nil
|
||||
s.session = nil
|
||||
s.stdin = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
if stdin != nil {
|
||||
if err := stdin.Close(); err != nil {
|
||||
log.Debugf("ssh: stdin close: %v", err)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh: session close: %v", err)
|
||||
}
|
||||
}
|
||||
var firstErr error
|
||||
if sshClient != nil {
|
||||
if err := sshClient.Close(); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
s.notifyClose("closed by client")
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SSHClient) buildAuth(cfg *profilemanager.Config, engine *internal.Engine,
|
||||
serverType detection.ServerType, password string) ([]gossh.AuthMethod, gossh.HostKeyCallback, error) {
|
||||
|
||||
switch serverType {
|
||||
case detection.ServerTypeNetBirdJWT:
|
||||
token, err := s.requestJWTToken(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("jwt: %w", err)
|
||||
}
|
||||
auths := []gossh.AuthMethod{gossh.Password(token)}
|
||||
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
|
||||
|
||||
case detection.ServerTypeNetBirdNoJWT:
|
||||
if cfg.SSHKey == "" {
|
||||
return nil, nil, errors.New("no NetBird SSH key available")
|
||||
}
|
||||
signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse netbird ssh key: %w", err)
|
||||
}
|
||||
auths := []gossh.AuthMethod{gossh.PublicKeys(signer)}
|
||||
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
|
||||
|
||||
default: // regular SSH
|
||||
var auths []gossh.AuthMethod
|
||||
if cfg.SSHKey != "" {
|
||||
if signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey)); err == nil {
|
||||
auths = append(auths, gossh.PublicKeys(signer))
|
||||
} else {
|
||||
log.Debugf("ssh: parse netbird key for regular auth: %v", err)
|
||||
}
|
||||
}
|
||||
if password != "" {
|
||||
pw := password
|
||||
auths = append(auths, gossh.Password(pw))
|
||||
auths = append(auths, gossh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
|
||||
answers := make([]string, len(questions))
|
||||
for i := range questions {
|
||||
answers[i] = pw
|
||||
}
|
||||
return answers, nil
|
||||
}))
|
||||
}
|
||||
if len(auths) == 0 {
|
||||
return nil, nil, errors.New("no auth method available: provide a password or configure NetBird SSH key")
|
||||
}
|
||||
return auths, gossh.InsecureIgnoreHostKey(), nil // nolint:gosec // TOFU not yet implemented
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHClient) requestJWTToken(cfg *profilemanager.Config) (string, error) {
|
||||
s.mu.Lock()
|
||||
urlOpener := s.urlOpener
|
||||
s.mu.Unlock()
|
||||
if urlOpener == nil {
|
||||
return "", errors.New("URL opener not configured for JWT auth")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
flow, err := auth.NewOAuthFlow(ctx, cfg, false, true, profilemanager.GetLoginHint())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create oauth flow: %w", err)
|
||||
}
|
||||
|
||||
flowInfo, err := flow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request auth info: %w", err)
|
||||
}
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
tokenInfo, err := flow.WaitToken(ctx, flowInfo)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wait for token: %w", err)
|
||||
}
|
||||
|
||||
token := tokenInfo.GetTokenToUse()
|
||||
if token == "" {
|
||||
return "", errors.New("empty token returned by IdP")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *SSHClient) dialAndHandshake(host string, port int, clientConfig *gossh.ClientConfig) error {
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
log.Infof("SSH: connecting to %s as %s", addr, clientConfig.User)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
var dialer net.Dialer
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
sshConn, chans, reqs, err := gossh.NewClientConn(conn, addr, clientConfig)
|
||||
if err != nil {
|
||||
if cerr := conn.Close(); cerr != nil {
|
||||
log.Debugf("ssh: close after handshake error: %v", cerr)
|
||||
}
|
||||
return fmt.Errorf("ssh handshake: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.sshClient = gossh.NewClient(sshConn, chans, reqs)
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("SSH: connected to %s", addr)
|
||||
if listener != nil {
|
||||
listener.OnConnected()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSHClient) readLoop(r io.Reader, name string) {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
s.mu.Lock()
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
if listener != nil {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
listener.OnData(chunk)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh %s read: %v", name, err)
|
||||
}
|
||||
s.notifyClose(err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHClient) notifyClose(reason string) {
|
||||
s.mu.Lock()
|
||||
if s.closed {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
s.closed = true
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
if listener != nil {
|
||||
listener.OnClose(reason)
|
||||
}
|
||||
}
|
||||
|
||||
// engineHostKeyVerifier adapts *internal.Engine to nbssh.HostKeyVerifier.
|
||||
type engineHostKeyVerifier struct {
|
||||
engine *internal.Engine
|
||||
}
|
||||
|
||||
func (v *engineHostKeyVerifier) VerifySSHHostKey(peerAddress string, presented []byte) error {
|
||||
storedKey, found := v.engine.GetPeerSSHKey(peerAddress)
|
||||
if !found {
|
||||
return nbssh.ErrPeerNotFound
|
||||
}
|
||||
return nbssh.VerifyHostKey(storedKey, presented, peerAddress)
|
||||
}
|
||||
|
||||
func closeQuiet(c io.Closer, label string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if err := c.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh: close %s: %v", label, err)
|
||||
}
|
||||
}
|
||||
|
||||
func detectServerType(host string, port int) detection.ServerType {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDetectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
log.Debugf("ssh: server detection for %s:%d failed: %v (assuming regular SSH)", host, port, err)
|
||||
return detection.ServerTypeRegular
|
||||
}
|
||||
return serverType
|
||||
}
|
||||
@@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
// Package firewalld integrates with the firewalld daemon so NetBird can place
|
||||
// its wg interface into firewalld's "trusted" zone. This is required because
|
||||
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
|
||||
// versions, which returns EPERM to any other process that tries to insert
|
||||
// rules into them. The workaround mirrors what Tailscale does: let firewalld
|
||||
// itself add the accept rules to its own chains by trusting the interface.
|
||||
package firewalld
|
||||
|
||||
// TrustedZone is the firewalld zone name used for interfaces whose traffic
|
||||
// should bypass firewalld filtering.
|
||||
const TrustedZone = "trusted"
|
||||
@@ -1,260 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
dbusDest = "org.fedoraproject.FirewallD1"
|
||||
dbusPath = "/org/fedoraproject/FirewallD1"
|
||||
dbusRootIface = "org.fedoraproject.FirewallD1"
|
||||
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
|
||||
|
||||
errZoneAlreadySet = "ZONE_ALREADY_SET"
|
||||
errAlreadyEnabled = "ALREADY_ENABLED"
|
||||
errUnknownIface = "UNKNOWN_INTERFACE"
|
||||
errNotEnabled = "NOT_ENABLED"
|
||||
|
||||
// callTimeout bounds each individual DBus or firewall-cmd invocation.
|
||||
// A fresh context is created for each call so a slow DBus probe can't
|
||||
// exhaust the deadline before the firewall-cmd fallback gets to run.
|
||||
callTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
errDBusUnavailable = errors.New("firewalld dbus unavailable")
|
||||
|
||||
// trustLogOnce ensures the "added to trusted zone" message is logged at
|
||||
// Info level only for the first successful add per process; repeat adds
|
||||
// from other init paths are quieter.
|
||||
trustLogOnce sync.Once
|
||||
|
||||
parentCtxMu sync.RWMutex
|
||||
parentCtx context.Context = context.Background()
|
||||
)
|
||||
|
||||
// SetParentContext installs a parent context whose cancellation aborts any
|
||||
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
||||
// always uses a fresh Background-rooted timeout so cleanup can still run
|
||||
// during engine shutdown when the engine context is already cancelled.
|
||||
func SetParentContext(ctx context.Context) {
|
||||
parentCtxMu.Lock()
|
||||
parentCtx = ctx
|
||||
parentCtxMu.Unlock()
|
||||
}
|
||||
|
||||
func getParentContext() context.Context {
|
||||
parentCtxMu.RLock()
|
||||
defer parentCtxMu.RUnlock()
|
||||
return parentCtx
|
||||
}
|
||||
|
||||
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
||||
// running. It is idempotent and best-effort: errors are returned so callers
|
||||
// can log, but a non-running firewalld is not an error. Only the first
|
||||
// successful call per process logs at Info. Respects the parent context set
|
||||
// via SetParentContext so startup-time cancellation unblocks it.
|
||||
func TrustInterface(iface string) error {
|
||||
parent := getParentContext()
|
||||
if !isRunning(parent) {
|
||||
return nil
|
||||
}
|
||||
if err := addTrusted(parent, iface); err != nil {
|
||||
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
trustLogOnce.Do(func() {
|
||||
log.Infof("added %s to firewalld trusted zone", iface)
|
||||
})
|
||||
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
||||
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
||||
// during shutdown after the engine context has been cancelled.
|
||||
func UntrustInterface(iface string) error {
|
||||
if !isRunning(context.Background()) {
|
||||
return nil
|
||||
}
|
||||
if err := removeTrusted(context.Background(), iface); err != nil {
|
||||
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(parent, callTimeout)
|
||||
}
|
||||
|
||||
func isRunning(parent context.Context) bool {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
ok, err := isRunningDBus(ctx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return ok
|
||||
}
|
||||
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return isRunningCLI(ctx)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := addDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return addCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func removeTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := removeDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return removeCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func isRunningDBus(ctx context.Context) (bool, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
var zone string
|
||||
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
||||
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isRunningCLI(ctx context.Context) bool {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return false
|
||||
}
|
||||
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
|
||||
}
|
||||
|
||||
func addDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
||||
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
||||
if move.Err != nil {
|
||||
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func removeDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func addCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
// --change-interface (no --permanent) binds the interface for the
|
||||
// current runtime only; we do not want membership to persist across
|
||||
// reboots because netbird re-asserts it on every startup.
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(string(out))
|
||||
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbusErrContains(err error, code string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var de dbus.Error
|
||||
if errors.As(err, &de) {
|
||||
for _, b := range de.Body {
|
||||
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Contains(err.Error(), code)
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
func TestDBusErrContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
code string
|
||||
want bool
|
||||
}{
|
||||
{"nil error", nil, errZoneAlreadySet, false},
|
||||
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
|
||||
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
|
||||
{
|
||||
"dbus.Error body match",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
|
||||
errZoneAlreadySet,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"dbus.Error body miss",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
|
||||
errAlreadyEnabled,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"dbus.Error non-string body falls back to Error()",
|
||||
dbus.Error{Name: "x", Body: []any{123}},
|
||||
"x",
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := dbusErrContains(tc.err, tc.code)
|
||||
if got != tc.want {
|
||||
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import "context"
|
||||
|
||||
// SetParentContext is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func SetParentContext(context.Context) {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
}
|
||||
|
||||
// TrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func TrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func UntrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -87,12 +86,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
|
||||
// Trust after all fatal init steps so a later failure doesn't leave the
|
||||
// interface in firewalld's trusted zone without a corresponding Close.
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -198,12 +191,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
}
|
||||
|
||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
// attempt to delete state only if all other operations succeeded
|
||||
if merr == nil {
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
@@ -230,11 +217,6 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -218,10 +217,6 @@ func (m *Manager) AllowNetbird() error {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
@@ -41,8 +40,6 @@ const (
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
firewalldTableName = "firewalld"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
@@ -136,10 +133,6 @@ func (r *router) Reset() error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||
}
|
||||
@@ -287,10 +280,6 @@ func (r *router) createContainers() error {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
@@ -1330,13 +1319,6 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||
if chain.Table.Name == firewalldTableName {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
|
||||
@@ -3,9 +3,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -19,9 +16,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
}
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -30,8 +24,5 @@ func (m *Manager) AllowNetbird() error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AllowNetbird()
|
||||
}
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() wgaddr.Address
|
||||
GetWGDevice() *wgdevice.Device
|
||||
|
||||
@@ -31,20 +31,12 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
NameFunc func() string
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() wgaddr.Address
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Name() string {
|
||||
if i.NameFunc == nil {
|
||||
return "wgtest"
|
||||
}
|
||||
return i.NameFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||
if i.GetWGDeviceFunc == nil {
|
||||
return nil
|
||||
|
||||
@@ -239,12 +239,8 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
||||
ipv6Count++
|
||||
}
|
||||
|
||||
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
||||
// routing-correctness checks above are the real assertions; the counts
|
||||
// are a sanity bound to catch a totally silent path.
|
||||
minDelivered := packetsPerFamily * 80 / 100
|
||||
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
||||
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||
}
|
||||
|
||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||
|
||||
@@ -201,18 +201,7 @@ Pop $0
|
||||
|
||||
Function .onInit
|
||||
StrCpy $INSTDIR "${INSTALL_DIR}"
|
||||
; Default autostart to enabled so silent installs (/S) match the interactive default
|
||||
StrCpy $AutostartEnabled "1"
|
||||
|
||||
; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live
|
||||
; in the 32-bit view. Fall back to it so upgrades still find them.
|
||||
SetRegView 64
|
||||
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
|
||||
${If} $R0 == ""
|
||||
SetRegView 32
|
||||
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
|
||||
SetRegView 64
|
||||
${EndIf}
|
||||
${If} $R0 != ""
|
||||
# if silent install jump to uninstall step
|
||||
IfSilent uninstall
|
||||
@@ -225,10 +214,6 @@ ${If} $R0 != ""
|
||||
|
||||
${EndIf}
|
||||
FunctionEnd
|
||||
|
||||
Function un.onInit
|
||||
SetRegView 64
|
||||
FunctionEnd
|
||||
######################################################################
|
||||
Section -MainProgram
|
||||
${INSTALL_TYPE}
|
||||
@@ -243,7 +228,6 @@ Section -MainProgram
|
||||
!else
|
||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||
!endif
|
||||
File "..\\client\\ui\\assets\\netbird.png"
|
||||
SectionEnd
|
||||
######################################################################
|
||||
|
||||
@@ -263,11 +247,9 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
; Create autostart registry entry based on checkbox
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
|
||||
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
@@ -301,8 +283,6 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart registry entry
|
||||
DetailPrint "Removing autostart registry entry if exists..."
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
|
||||
; Handle data deletion based on checkbox
|
||||
@@ -341,7 +321,6 @@ DetailPrint "Removing registry keys..."
|
||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}"
|
||||
|
||||
DetailPrint "Removing application directory from PATH..."
|
||||
EnVar::SetHKLM
|
||||
|
||||
@@ -333,10 +333,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.statusRecorder.MarkSignalConnected()
|
||||
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
||||
relayURLs = override
|
||||
}
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||
|
||||
@@ -3,12 +3,10 @@ package debug
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -21,10 +19,8 @@ func TestUpload(t *testing.T) {
|
||||
t.Skip("Skipping upload test on docker ci")
|
||||
}
|
||||
testDir := t.TempDir()
|
||||
addr := reserveLoopbackPort(t)
|
||||
testURL := "http://" + addr
|
||||
testURL := "http://localhost:8080"
|
||||
t.Setenv("SERVER_URL", testURL)
|
||||
t.Setenv("SERVER_ADDRESS", addr)
|
||||
t.Setenv("STORE_DIR", testDir)
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
@@ -37,7 +33,6 @@ func TestUpload(t *testing.T) {
|
||||
t.Errorf("Failed to stop server: %v", err)
|
||||
}
|
||||
})
|
||||
waitForServer(t, addr)
|
||||
|
||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||
fileContent := []byte("test file content")
|
||||
@@ -52,30 +47,3 @@ func TestUpload(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fileContent, createdFileContent)
|
||||
}
|
||||
|
||||
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
||||
// address, then releases it so the server under test can rebind. The close/
|
||||
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
||||
// it's essentially never contended in practice.
|
||||
func reserveLoopbackPort(t *testing.T) string {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := l.Addr().String()
|
||||
require.NoError(t, l.Close())
|
||||
return addr
|
||||
}
|
||||
|
||||
func waitForServer(t *testing.T, addr string) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("server did not start listening on %s in time", addr)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
nsswitchConfPath = "/etc/nsswitch.conf"
|
||||
)
|
||||
|
||||
type resolvConf struct {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
c.dispatch(w, r, math.MaxInt)
|
||||
}
|
||||
|
||||
// dispatch routes a DNS request through the chain, skipping handlers with
|
||||
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
||||
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range handlers {
|
||||
if entry.Priority > maxPriority {
|
||||
continue
|
||||
}
|
||||
if !c.isHandlerMatch(qname, entry) {
|
||||
continue
|
||||
}
|
||||
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
||||
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
||||
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
||||
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
||||
// (bounded by the invoked handler's internal timeout).
|
||||
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||
if len(r.Question) == 0 {
|
||||
return nil, fmt.Errorf("empty question")
|
||||
}
|
||||
|
||||
base := &internalResponseWriter{}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.dispatch(base, r, maxPriority)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
||||
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
||||
strings.ToLower(r.Question[0].Name), maxPriority)
|
||||
}
|
||||
return base.response, nil
|
||||
}
|
||||
|
||||
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
||||
// priority ≤ maxPriority.
|
||||
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
for _, h := range c.handlers {
|
||||
if h.Pattern == "." && h.Priority <= maxPriority {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
||||
type internalResponseWriter struct {
|
||||
response *dns.Msg
|
||||
}
|
||||
|
||||
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
||||
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
|
||||
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
||||
// still surface their answer to ResolveInternal.
|
||||
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.response = msg
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *internalResponseWriter) Close() error { return nil }
|
||||
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
||||
|
||||
// TsigTimersOnly is part of dns.ResponseWriter.
|
||||
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
||||
// no-op: in-process queries carry no TSIG state.
|
||||
}
|
||||
|
||||
// Hijack is part of dns.ResponseWriter.
|
||||
func (w *internalResponseWriter) Hijack() {
|
||||
// no-op: in-process queries have no underlying connection to hand off.
|
||||
}
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
package dns_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
||||
// which handler ResolveInternal dispatches to.
|
||||
type answeringHandler struct {
|
||||
name string
|
||||
ip string
|
||||
}
|
||||
|
||||
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP(h.ip).To4(),
|
||||
}}
|
||||
_ = w.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *answeringHandler) String() string { return h.name }
|
||||
|
||||
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
||||
|
||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, 1, len(resp.Answer))
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.Error(t, err, "no handler at or below maxPriority should error")
|
||||
}
|
||||
|
||||
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
||||
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
||||
type rawWriteHandler struct {
|
||||
ip string
|
||||
}
|
||||
|
||||
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP(h.ip).To4(),
|
||||
}}
|
||||
packed, err := resp.Pack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(packed)
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
||||
type hangingHandler struct {
|
||||
block chan struct{}
|
||||
}
|
||||
|
||||
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
<-h.block
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
_ = w.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *hangingHandler) String() string { return "hangingHandler" }
|
||||
|
||||
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
h := &hangingHandler{block: make(chan struct{})}
|
||||
defer close(h.block)
|
||||
|
||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
||||
}
|
||||
|
||||
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
||||
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
||||
|
||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
||||
|
||||
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
||||
|
||||
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
||||
|
||||
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||
|
||||
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
||||
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
||||
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
||||
|
||||
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
||||
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
||||
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||
}
|
||||
|
||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, reason, err := getOSDNSManagerType()
|
||||
osManager, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||
if err != nil {
|
||||
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
||||
}
|
||||
}
|
||||
|
||||
func getOSDNSManagerType() (osManagerType, string, error) {
|
||||
resolved := isSystemdResolvedRunning()
|
||||
nss := isLibnssResolveUsed()
|
||||
stub := checkStub()
|
||||
|
||||
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
||||
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
||||
// that go through nss-resolve, and in foreign mode they can loop back
|
||||
// through resolved as an upstream.
|
||||
if resolved && (nss || stub) {
|
||||
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
||||
}
|
||||
|
||||
mgr, reason, rejected, err := scanResolvConfHeader()
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
if reason != "" {
|
||||
return mgr, reason, nil
|
||||
}
|
||||
|
||||
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
||||
if len(rejected) > 0 {
|
||||
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
||||
}
|
||||
return fileManager, fallback, nil
|
||||
}
|
||||
|
||||
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
||||
// matching manager. If reason is empty the caller should pick file mode and
|
||||
// use rejected for diagnostics.
|
||||
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||
func getOSDNSManagerType() (osManagerType, error) {
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if cerr := file.Close(); cerr != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var rejected []string
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||
continue
|
||||
}
|
||||
if text[0] != '#' {
|
||||
break
|
||||
return fileManager, nil
|
||||
}
|
||||
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
||||
return mgr, reason, nil, nil
|
||||
} else if rej != "" {
|
||||
rejected = append(rejected, rej)
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, nil
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
return fileManager, nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, nil
|
||||
}
|
||||
|
||||
return resolvConfManager, nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
||||
return 0, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
return 0, "", rejected, nil
|
||||
|
||||
return fileManager, nil
|
||||
}
|
||||
|
||||
// matchResolvConfHeader inspects a single comment line. Returns either a
|
||||
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
||||
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") {
|
||||
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, "NetworkManager header + supported version on dbus", ""
|
||||
}
|
||||
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
||||
}
|
||||
return resolvConfManager, "resolvconf header detected", ""
|
||||
}
|
||||
return 0, "", ""
|
||||
}
|
||||
|
||||
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
||||
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
||||
// into file mode while resolved is active.
|
||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||
func checkStub() bool {
|
||||
rConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
||||
log.Warnf("failed to parse resolv conf: %s", err)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -178,36 +139,3 @@ func checkStub() bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
||||
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
||||
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
||||
func isLibnssResolveUsed() bool {
|
||||
bs, err := os.ReadFile(nsswitchConfPath)
|
||||
if err != nil {
|
||||
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
||||
return false
|
||||
}
|
||||
return parseNsswitchResolveAhead(bs)
|
||||
}
|
||||
|
||||
func parseNsswitchResolveAhead(data []byte) bool {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if i := strings.IndexByte(line, '#'); i >= 0 {
|
||||
line = line[:i]
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 || fields[0] != "hosts:" {
|
||||
continue
|
||||
}
|
||||
for _, module := range fields[1:] {
|
||||
switch module {
|
||||
case "dns":
|
||||
return false
|
||||
case "resolve":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseNsswitchResolveAhead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "resolve before dns with action token",
|
||||
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "dns before resolve",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "debian default with only dns",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "neither resolve nor dns",
|
||||
in: "hosts: files myhostname\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no hosts line",
|
||||
in: "passwd: files systemd\ngroup: files systemd\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
in: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "comments and blank lines ignored",
|
||||
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "trailing inline comment",
|
||||
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "hosts token must be the first field",
|
||||
in: " hosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "other db line mentioning resolve is ignored",
|
||||
in: "networks: resolve\nhosts: dns\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "only resolve, no dns",
|
||||
in: "hosts: files resolve\n",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
||||
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,83 +2,40 @@ package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTimeout = 5 * time.Second
|
||||
defaultTTL = 300 * time.Second
|
||||
refreshBackoff = 30 * time.Second
|
||||
const dnsTimeout = 5 * time.Second
|
||||
|
||||
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
||||
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
||||
)
|
||||
|
||||
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
||||
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
||||
// system resolver.
|
||||
type ChainResolver interface {
|
||||
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
||||
HasRootHandlerAtOrBelow(maxPriority int) bool
|
||||
}
|
||||
|
||||
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
||||
// records and cachedAt are set at construction and treated as immutable;
|
||||
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
||||
// Resolver.mutex.
|
||||
type cachedRecord struct {
|
||||
records []dns.RR
|
||||
cachedAt time.Time
|
||||
lastFailedRefresh time.Time
|
||||
consecFailures int
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
// Resolver caches critical NetBird infrastructure domains
|
||||
type Resolver struct {
|
||||
records map[dns.Question]*cachedRecord
|
||||
records map[dns.Question][]dns.RR
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
|
||||
// refreshing tracks questions whose refresh is running via the OS
|
||||
// fallback path. A ServeDNS hit for a question in this map indicates
|
||||
// the OS resolver routed the recursive query back to us (loop). Only
|
||||
// the OS path arms this so chain-path refreshes don't produce false
|
||||
// positives. The atomic bool is CAS-flipped once per refresh to
|
||||
// throttle the warning log.
|
||||
refreshing map[dns.Question]*atomic.Bool
|
||||
|
||||
cacheTTL time.Duration
|
||||
type ipsResponse struct {
|
||||
ips []netip.Addr
|
||||
err error
|
||||
}
|
||||
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
|
||||
return "MgmtCacheResolver"
|
||||
}
|
||||
|
||||
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
||||
// maxPriority caps which handlers may answer refresh queries (typically
|
||||
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
||||
// mgmt/route/local handlers are skipped).
|
||||
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
||||
m.mutex.Lock()
|
||||
m.chain = chain
|
||||
m.chainMaxPriority = maxPriority
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
||||
// immediately and refreshed asynchronously (stale-while-revalidate).
|
||||
// ServeDNS implements dns.Handler interface.
|
||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
m.continueToNext(w, r)
|
||||
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
cached, found := m.records[question]
|
||||
inflight := m.refreshing[question]
|
||||
var shouldRefresh bool
|
||||
if found {
|
||||
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
||||
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
||||
shouldRefresh = stale && !inBackoff
|
||||
}
|
||||
records, found := m.records[question]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
||||
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
||||
question.Name)
|
||||
}
|
||||
|
||||
// Skip scheduling a refresh goroutine if one is already inflight for
|
||||
// this question; singleflight would dedup anyway but skipping avoids
|
||||
// a parked goroutine per stale hit under bursty load.
|
||||
if shouldRefresh && inflight == nil {
|
||||
m.scheduleRefresh(question, cached)
|
||||
}
|
||||
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Authoritative = false
|
||||
resp.RecursionAvailable = true
|
||||
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
|
||||
|
||||
resp.Answer = append(resp.Answer, records...)
|
||||
|
||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||
|
||||
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype.
|
||||
// AddDomain manually adds a domain to cache by resolving it.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
||||
|
||||
if errA != nil && errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
||||
if err := errors.Join(errA, errAAAA); err != nil {
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
||||
var aRecords, aaaaRecords []dns.RR
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
rr := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
}
|
||||
aRecords = append(aRecords, rr)
|
||||
} else if ip.Is6() {
|
||||
rr := &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
}
|
||||
aaaaRecords = append(aaaaRecords, rr)
|
||||
}
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
||||
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
||||
if len(aRecords) > 0 {
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aQuestion] = aRecords
|
||||
}
|
||||
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
if len(aaaaRecords) > 0 {
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aaaaQuestion] = aaaaRecords
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
||||
// untouched on error. Caller holds m.mutex.
|
||||
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
switch {
|
||||
case len(records) > 0:
|
||||
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
||||
case err == nil:
|
||||
delete(m.records, q)
|
||||
}
|
||||
}
|
||||
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||
resultChan := make(chan *ipsResponse, 1)
|
||||
|
||||
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
||||
// unique in-flight key; bursty stale hits share its channel. expected is the
|
||||
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
||||
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
||||
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
||||
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
||||
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
||||
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
||||
return nil, m.refreshQuestion(question, expected)
|
||||
})
|
||||
}
|
||||
|
||||
// refreshQuestion replaces the cached records on success, or marks the entry
|
||||
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
||||
// a resolver loop by spotting a query for this same question arriving on us.
|
||||
// expected pins the cache entry observed at schedule time; mutations only apply
|
||||
// if m.records[question] still points at it.
|
||||
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
||||
if err != nil {
|
||||
m.markRefreshFailed(question, expected)
|
||||
return fmt.Errorf("parse domain: %w", err)
|
||||
}
|
||||
|
||||
records, err := m.lookupRecords(ctx, d, question)
|
||||
if err != nil {
|
||||
fails := m.markRefreshFailed(question, expected)
|
||||
logf := log.Warnf
|
||||
if fails == 0 || fails > 1 {
|
||||
logf = log.Debugf
|
||||
go func() {
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||
resultChan <- &ipsResponse{
|
||||
err: err,
|
||||
ips: ips,
|
||||
}
|
||||
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
||||
return err
|
||||
}()
|
||||
|
||||
var resp *ipsResponse
|
||||
|
||||
select {
|
||||
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case resp = <-resultChan:
|
||||
}
|
||||
|
||||
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
||||
if len(records) == 0 {
|
||||
m.mutex.Lock()
|
||||
if m.records[question] == expected {
|
||||
delete(m.records, question)
|
||||
m.mutex.Unlock()
|
||||
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
if resp.err != nil {
|
||||
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
if m.records[question] != expected {
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Resolver) markRefreshing(question dns.Question) {
|
||||
m.mutex.Lock()
|
||||
m.refreshing[question] = &atomic.Bool{}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearRefreshing(question dns.Question) {
|
||||
m.mutex.Lock()
|
||||
delete(m.refreshing, question)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
||||
// count so callers can downgrade subsequent failure logs to debug.
|
||||
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
c, ok := m.records[question]
|
||||
if !ok || c != expected {
|
||||
return 0
|
||||
}
|
||||
c.lastFailedRefresh = time.Now()
|
||||
c.consecFailures++
|
||||
return c.consecFailures
|
||||
}
|
||||
|
||||
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
||||
// callers tell records, NODATA (nil err, no records), and failure apart.
|
||||
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
||||
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
||||
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
||||
// not the system resolver, so net.DefaultResolver will not loop back.
|
||||
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
||||
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
||||
return
|
||||
}
|
||||
|
||||
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
||||
// arms the loop detector for the duration of its call so that ServeDNS can
|
||||
// spot the OS resolver routing the recursive query back to us.
|
||||
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
||||
}
|
||||
|
||||
// TODO: drop once every supported OS registers a fallback resolver.
|
||||
m.markRefreshing(q)
|
||||
defer m.clearRefreshing(q)
|
||||
|
||||
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
||||
}
|
||||
|
||||
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
||||
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
|
||||
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
||||
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||
msg := &dns.Msg{}
|
||||
msg.SetQuestion(dnsName, qtype)
|
||||
msg.RecursionDesired = true
|
||||
|
||||
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chain resolve: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("chain resolve returned nil response")
|
||||
}
|
||||
if resp.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
||||
}
|
||||
|
||||
ttl := uint32(m.cacheTTL.Seconds())
|
||||
owners := cnameOwners(dnsName, resp.Answer)
|
||||
var filtered []dns.RR
|
||||
for _, rr := range resp.Answer {
|
||||
h := rr.Header()
|
||||
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
||||
continue
|
||||
}
|
||||
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
||||
continue
|
||||
}
|
||||
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
||||
filtered = append(filtered, cp)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
||||
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
||||
// returns (nil, nil).
|
||||
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||
network := resutil.NetworkForQtype(qtype)
|
||||
if network == "" {
|
||||
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
||||
}
|
||||
|
||||
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||
|
||||
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
||||
if result.Rcode == dns.RcodeSuccess {
|
||||
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
||||
}
|
||||
|
||||
if result.Err != nil {
|
||||
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
||||
}
|
||||
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
||||
}
|
||||
|
||||
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
||||
// so downstream resolvers don't cache an answer for longer than we will.
|
||||
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
||||
remaining := m.cacheTTL - time.Since(cachedAt)
|
||||
if remaining <= 0 {
|
||||
return 0
|
||||
}
|
||||
return uint32((remaining + time.Second - 1) / time.Second)
|
||||
return resp.ips, nil
|
||||
}
|
||||
|
||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
||||
delete(m.records, qA)
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
delete(m.records, aQuestion)
|
||||
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
delete(m.records, aaaaQuestion)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
||||
// A/AAAA records return nil.
|
||||
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.A = slices.Clone(r.A)
|
||||
return &cp
|
||||
case *dns.AAAA:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.AAAA = slices.Clone(r.AAAA)
|
||||
return &cp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
||||
// stamping ttl so the response shares no memory with the cached slice.
|
||||
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
||||
out := make([]dns.RR, 0, len(records))
|
||||
for _, rr := range records {
|
||||
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
||||
out = append(out, cp)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
||||
// in answer, iterating until fixed point so out-of-order chains resolve.
|
||||
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
||||
owners := map[string]bool{dnsName: true}
|
||||
for {
|
||||
added := false
|
||||
for _, rr := range answer {
|
||||
cname, ok := rr.(*dns.CNAME)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
||||
if !owners[name] {
|
||||
continue
|
||||
}
|
||||
target := strings.ToLower(dns.Fqdn(cname.Target))
|
||||
if !owners[target] {
|
||||
owners[target] = true
|
||||
added = true
|
||||
}
|
||||
}
|
||||
if !added {
|
||||
return owners
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
||||
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
||||
func resolveCacheTTL() time.Duration {
|
||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return defaultTTL
|
||||
}
|
||||
|
||||
@@ -1,408 +0,0 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
}
|
||||
|
||||
func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.hasRoot
|
||||
}
|
||||
|
||||
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||
f.mu.Lock()
|
||||
q := msg.Question[0]
|
||||
key := q.Name + "|" + dns.TypeToString[q.Qtype]
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
if onLookup != nil {
|
||||
onLookup()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(msg)
|
||||
resp.Answer = answers
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
key := name + "|" + dns.TypeToString[qtype]
|
||||
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
|
||||
case dns.TypeAAAA:
|
||||
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.calls[name+"|"+dns.TypeToString[qtype]]
|
||||
}
|
||||
|
||||
// waitFor polls the predicate until it returns true or the deadline passes.
|
||||
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(d)
|
||||
for time.Now().Before(deadline) {
|
||||
if fn() {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("condition not met within %s", d)
|
||||
}
|
||||
|
||||
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
|
||||
t.Helper()
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(name, dns.TypeA)
|
||||
w := &test.MockResponseWriter{}
|
||||
r.ServeDNS(w, msg)
|
||||
return w.GetLastResponse()
|
||||
}
|
||||
|
||||
func firstA(t *testing.T, resp *dns.Msg) string {
|
||||
t.Helper()
|
||||
require.NotNil(t, resp)
|
||||
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
require.True(t, ok, "expected A record")
|
||||
return a.A.String()
|
||||
}
|
||||
|
||||
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
// Same cached entry age, different cacheTTL values: the shorter TTL must
|
||||
// trigger a background refresh, the longer one must not. Proves that the
|
||||
// per-Resolver cacheTTL field actually drives the stale decision.
|
||||
cachedAt := time.Now().Add(-100 * time.Millisecond)
|
||||
|
||||
newRec := func() *cachedRecord {
|
||||
return &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: cachedAt,
|
||||
}
|
||||
}
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
|
||||
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r.cacheTTL = 10 * time.Millisecond
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
r.records[q] = newRec()
|
||||
|
||||
resp := queryA(t, r, q.Name)
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount(q.Name, dns.TypeA) >= 1
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r.cacheTTL = time.Hour
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
r.records[q] = newRec()
|
||||
|
||||
resp := queryA(t, r, q.Name)
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(), // fresh
|
||||
}
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
|
||||
}
|
||||
|
||||
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
|
||||
}
|
||||
|
||||
// First query: serves stale immediately.
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
|
||||
})
|
||||
|
||||
// Next query should now return the refreshed IP.
|
||||
waitFor(t, time.Second, func() bool {
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
|
||||
var inflight atomic.Int32
|
||||
var maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
cur := inflight.Add(1)
|
||||
defer inflight.Add(-1)
|
||||
for {
|
||||
prev := maxInflight.Load()
|
||||
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
|
||||
}
|
||||
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
waitFor(t, 2*time.Second, func() bool {
|
||||
return inflight.Load() == 0
|
||||
})
|
||||
|
||||
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
|
||||
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
|
||||
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
|
||||
}
|
||||
|
||||
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||
}
|
||||
|
||||
// First stale hit triggers a refresh attempt that fails.
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
|
||||
})
|
||||
waitFor(t, time.Second, func() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
c, ok := r.records[q]
|
||||
return ok && !c.lastFailedRefresh.IsZero()
|
||||
})
|
||||
|
||||
// Subsequent stale hits within backoff window should not schedule more refreshes.
|
||||
for i := 0; i < 10; i++ {
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
|
||||
}
|
||||
|
||||
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.hasRoot = false
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
// With hasRoot=false the chain must not be consulted. Use a short
|
||||
// deadline so the OS fallback returns quickly without waiting on a
|
||||
// real network call in CI.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
|
||||
|
||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
|
||||
"chain must not be used when no root handler is registered at the bound priority")
|
||||
}
|
||||
|
||||
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||
// ServeDNS being invoked for a question while a refresh for that question
|
||||
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
||||
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Simulate an inflight refresh.
|
||||
r.markRefreshing(q)
|
||||
defer r.clearRefreshing(q)
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
|
||||
|
||||
r.mutex.RLock()
|
||||
inflight := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
require.NotNil(t, inflight)
|
||||
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
|
||||
}
|
||||
|
||||
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
r.markRefreshing(q)
|
||||
defer r.clearRefreshing(q)
|
||||
|
||||
// Multiple ServeDNS calls during the same refresh must not re-set the flag
|
||||
// (CompareAndSwap from false -> true returns true only on the first call).
|
||||
for range 5 {
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}
|
||||
|
||||
r.mutex.RLock()
|
||||
inflight := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
assert.True(t, inflight.Load())
|
||||
}
|
||||
|
||||
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
|
||||
r.mutex.RLock()
|
||||
_, ok := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, ok, "no refresh inflight means no loop tracking")
|
||||
}
|
||||
|
||||
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.2", firstA(t, resp))
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
|
||||
assert.False(t, resolver.MatchSubdomains())
|
||||
}
|
||||
|
||||
func TestResolveCacheTTL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want time.Duration
|
||||
}{
|
||||
{"unset falls back to default", "", defaultTTL},
|
||||
{"valid duration", "45s", 45 * time.Second},
|
||||
{"valid minutes", "2m", 2 * time.Minute},
|
||||
{"malformed falls back to default", "not-a-duration", defaultTTL},
|
||||
{"zero falls back to default", "0s", defaultTTL},
|
||||
{"negative falls back to default", "-5s", defaultTTL},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, tc.value)
|
||||
got := resolveCacheTTL()
|
||||
assert.Equal(t, tc.want, got, "parsed TTL should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, "7s")
|
||||
r := NewResolver()
|
||||
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
||||
}
|
||||
|
||||
func TestResolver_ResponseTTL(t *testing.T) {
|
||||
now := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
cacheTTL time.Duration
|
||||
cachedAt time.Time
|
||||
wantMin uint32
|
||||
wantMax uint32
|
||||
}{
|
||||
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
|
||||
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
|
||||
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
|
||||
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := &Resolver{cacheTTL: tc.cacheTTL}
|
||||
got := r.responseTTL(tc.cachedAt)
|
||||
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
|
||||
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -212,7 +212,6 @@ func newDefaultServer(
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
mgmtCacheResolver := mgmt.NewResolver()
|
||||
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -571,7 +570,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
e.srWatcher.Start()
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
@@ -605,8 +604,6 @@ func (e *Engine) createFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
firewalld.SetParentContext(e.ctx)
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil {
|
||||
@@ -944,12 +941,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
return fmt.Errorf("update relay token: %w", err)
|
||||
}
|
||||
|
||||
urls := update.Urls
|
||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
||||
urls = override
|
||||
}
|
||||
e.relayManager.UpdateServerURLs(urls)
|
||||
e.relayManager.UpdateServerURLs(update.Urls)
|
||||
|
||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||
|
||||
@@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package activity
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -17,6 +18,10 @@ import (
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
func isBindListenerPlatform() bool {
|
||||
return runtime.GOOS == "windows" || runtime.GOOS == "js"
|
||||
}
|
||||
|
||||
// mockEndpointManager implements device.EndpointManager for testing
|
||||
type mockEndpointManager struct {
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
@@ -176,6 +181,10 @@ func TestBindListener_Close(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestManager_BindMode(t *testing.T) {
|
||||
if !isBindListenerPlatform() {
|
||||
t.Skip("BindListener only used on Windows/JS platforms")
|
||||
}
|
||||
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
@@ -217,6 +226,10 @@ func TestManager_BindMode(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
||||
if !isBindListenerPlatform() {
|
||||
t.Skip("BindListener only used on Windows/JS platforms")
|
||||
}
|
||||
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
|
||||
@@ -4,12 +4,14 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
@@ -73,6 +75,16 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
// BindListener is used on Windows, JS, and netstack platforms:
|
||||
// - JS: Cannot listen to UDP sockets
|
||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||
// gateway points to, preventing them from reaching the loopback interface.
|
||||
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||
// BindListener bypasses these issues by passing data directly through the bind.
|
||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
provider, ok := m.wgIface.(bindProvider)
|
||||
if !ok {
|
||||
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
|
||||
|
||||
@@ -185,20 +185,17 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||
|
||||
forceRelay := IsForceRelayed()
|
||||
if !forceRelay {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
||||
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
if !forceRelay {
|
||||
if !isForceRelayed() {
|
||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
@@ -254,9 +251,7 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.wgWatcherCancel()
|
||||
}
|
||||
conn.workerRelay.CloseConn()
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.Close()
|
||||
}
|
||||
conn.workerICE.Close()
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
err := conn.wgProxyRelay.CloseConn()
|
||||
@@ -299,9 +294,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||
conn.dumpState.RemoteCandidate()
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
|
||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||
@@ -719,35 +712,33 @@ func (conn *Conn) evalStatus() ConnStatus {
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
|
||||
//
|
||||
// The result is a tri-state:
|
||||
// - ConnStatusConnected: all available transports are up
|
||||
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
|
||||
// - ConnStatusDisconnected: no working transport
|
||||
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||
|
||||
defer func() {
|
||||
if status == guard.ConnStatusDisconnected {
|
||||
if !connected {
|
||||
conn.logTraceConnState()
|
||||
}
|
||||
}()
|
||||
|
||||
iceWorkerCreated := conn.workerICE != nil
|
||||
|
||||
var iceInProgress bool
|
||||
if iceWorkerCreated {
|
||||
iceInProgress = conn.workerICE.InProgress()
|
||||
// For JS platform: only relay connection is supported
|
||||
if runtime.GOOS == "js" {
|
||||
return conn.statusRelay.Get() == worker.StatusConnected
|
||||
}
|
||||
|
||||
return evalConnStatus(connStatusInputs{
|
||||
forceRelay: IsForceRelayed(),
|
||||
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
|
||||
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
|
||||
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
|
||||
iceWorkerCreated: iceWorkerCreated,
|
||||
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
|
||||
iceInProgress: iceInProgress,
|
||||
})
|
||||
// For non-JS platforms: check ICE connection status
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
// If relay is supported with peer, it must also be connected
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||
@@ -935,43 +926,3 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
|
||||
// "Relay up and needed" — the peer uses relay and the transport is connected.
|
||||
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
|
||||
|
||||
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
|
||||
if in.forceRelay {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// Remote peer doesn't support ICE, or we haven't created the worker yet:
|
||||
// relay is the only possible transport.
|
||||
if !in.remoteSupportsICE || !in.iceWorkerCreated {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// ICE counts as "up" when the status is anything other than Disconnected, OR
|
||||
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
|
||||
iceUp := in.iceStatusConnecting || in.iceInProgress
|
||||
|
||||
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
|
||||
relayOK := !in.peerUsesRelay || in.relayConnected
|
||||
|
||||
switch {
|
||||
case iceUp && relayOK:
|
||||
return guard.ConnStatusConnected
|
||||
case relayUsedAndUp:
|
||||
// Relay is up but ICE is down — partially connected.
|
||||
return guard.ConnStatusPartiallyConnected
|
||||
default:
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
}
|
||||
|
||||
func boolToConnStatus(connected bool) guard.ConnStatus {
|
||||
if connected {
|
||||
return guard.ConnStatusConnected
|
||||
}
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
|
||||
@@ -13,20 +13,6 @@ const (
|
||||
StatusConnected
|
||||
)
|
||||
|
||||
// connStatusInputs is the primitive-valued snapshot of the state that drives the
|
||||
// tri-state connection classification. Extracted so the decision logic can be unit-tested
|
||||
// without constructing full Worker/Handshaker objects.
|
||||
type connStatusInputs struct {
|
||||
forceRelay bool // NB_FORCE_RELAY or JS/WASM
|
||||
peerUsesRelay bool // remote peer advertises relay support AND local has relay
|
||||
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
|
||||
remoteSupportsICE bool // remote peer sent ICE credentials
|
||||
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
|
||||
iceStatusConnecting bool // statusICE is anything other than Disconnected
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
)
|
||||
|
||||
func TestEvalConnStatus_ForceRelay(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "force relay, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer does NOT use relay - disconnected forever",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: false,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE worker not yet created, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: false,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer does not use relay",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: false,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
|
||||
base := connStatusInputs{
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutator func(*connStatusInputs)
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "ICE connected, relay connected, peer uses relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE connected, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE InProgress only, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up, peer uses relay -> partial",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusPartiallyConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, peer does NOT use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
// relayOK = false (peer uses relay but it's down), iceUp = true
|
||||
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
|
||||
// falls into default: Disconnected.
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up but peer does not use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = true // not actually used since peer doesn't rely on it
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
in := base
|
||||
tc.mutator(&in)
|
||||
if got := evalConnStatus(in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,38 +7,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
)
|
||||
|
||||
func IsForceRelayed() bool {
|
||||
func isForceRelayed() bool {
|
||||
if runtime.GOOS == "js" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||
}
|
||||
|
||||
// OverrideRelayURLs returns the relay server URL list set in
|
||||
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
|
||||
// the override is active. When the env var is unset, the boolean is false
|
||||
// and the caller should keep the list received from the management server.
|
||||
// Intended for lab/debug scenarios where a peer must pin to a specific home
|
||||
// relay regardless of what management offers.
|
||||
func OverrideRelayURLs() ([]string, bool) {
|
||||
raw := os.Getenv(EnvKeyNBHomeRelayServers)
|
||||
if raw == "" {
|
||||
return nil, false
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
urls := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
urls = append(urls, p)
|
||||
}
|
||||
}
|
||||
if len(urls) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return urls, true
|
||||
}
|
||||
|
||||
@@ -8,19 +8,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ConnStatus represents the connection state as seen by the guard.
|
||||
type ConnStatus int
|
||||
|
||||
const (
|
||||
// ConnStatusDisconnected means neither ICE nor Relay is connected.
|
||||
ConnStatusDisconnected ConnStatus = iota
|
||||
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
|
||||
ConnStatusPartiallyConnected
|
||||
// ConnStatusConnected means all required connections are established.
|
||||
ConnStatusConnected
|
||||
)
|
||||
|
||||
type connStatusFunc func() ConnStatus
|
||||
type isConnectedFunc func() bool
|
||||
|
||||
// Guard is responsible for the reconnection logic.
|
||||
// It will trigger to send an offer to the peer then has connection issues.
|
||||
@@ -32,14 +20,14 @@ type connStatusFunc func() ConnStatus
|
||||
// - ICE candidate changes
|
||||
type Guard struct {
|
||||
log *log.Entry
|
||||
isConnectedOnAllWay connStatusFunc
|
||||
isConnectedOnAllWay isConnectedFunc
|
||||
timeout time.Duration
|
||||
srWatcher *SRWatcher
|
||||
relayedConnDisconnected chan struct{}
|
||||
iCEConnDisconnected chan struct{}
|
||||
}
|
||||
|
||||
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
return &Guard{
|
||||
log: log,
|
||||
isConnectedOnAllWay: isConnectedFn,
|
||||
@@ -69,17 +57,8 @@ func (g *Guard) SetICEConnDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
|
||||
//
|
||||
// Behavior depends on the connection state reported by isConnectedOnAllWay:
|
||||
// - Connected: no action, the peer is fully reachable.
|
||||
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
|
||||
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
|
||||
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
|
||||
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
|
||||
//
|
||||
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
|
||||
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
|
||||
// reconnectLoopWithRetry periodically check the connection status.
|
||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
@@ -89,47 +68,36 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
|
||||
tickerChannel := ticker.C
|
||||
|
||||
iceState := &iceRetryState{log: g.log}
|
||||
defer iceState.reset()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tickerChannel:
|
||||
switch g.isConnectedOnAllWay() {
|
||||
case ConnStatusConnected:
|
||||
// all good, nothing to do
|
||||
case ConnStatusDisconnected:
|
||||
callback()
|
||||
case ConnStatusPartiallyConnected:
|
||||
if iceState.shouldRetry() {
|
||||
callback()
|
||||
} else {
|
||||
iceState.enterHourlyMode()
|
||||
ticker.Stop()
|
||||
tickerChannel = iceState.hourlyC()
|
||||
}
|
||||
case t := <-tickerChannel:
|
||||
if t.IsZero() {
|
||||
g.log.Infof("retry timed out, stop periodic offer sending")
|
||||
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
|
||||
tickerChannel = make(<-chan time.Time)
|
||||
continue
|
||||
}
|
||||
|
||||
if !g.isConnectedOnAllWay() {
|
||||
callback()
|
||||
}
|
||||
case <-g.relayedConnDisconnected:
|
||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-g.iCEConnDisconnected:
|
||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-srReconnectedChan:
|
||||
g.log.Debugf("has network changes, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
@@ -152,7 +120,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
|
||||
return backoff.NewTicker(bo)
|
||||
}
|
||||
|
||||
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
|
||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
|
||||
maxICERetries = 3
|
||||
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
|
||||
iceRetryInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
|
||||
// After maxICERetries attempts it switches to a periodic hourly retry.
|
||||
type iceRetryState struct {
|
||||
log *log.Entry
|
||||
retries int
|
||||
hourly *time.Ticker
|
||||
}
|
||||
|
||||
func (s *iceRetryState) reset() {
|
||||
s.retries = 0
|
||||
if s.hourly != nil {
|
||||
s.hourly.Stop()
|
||||
s.hourly = nil
|
||||
}
|
||||
}
|
||||
|
||||
// shouldRetry reports whether the caller should send another ICE offer on this tick.
|
||||
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
|
||||
// to the hourly ticker via enterHourlyMode + hourlyC.
|
||||
func (s *iceRetryState) shouldRetry() bool {
|
||||
if s.hourly != nil {
|
||||
s.log.Debugf("hourly ICE retry attempt")
|
||||
return true
|
||||
}
|
||||
|
||||
s.retries++
|
||||
if s.retries <= maxICERetries {
|
||||
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
|
||||
func (s *iceRetryState) enterHourlyMode() {
|
||||
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
|
||||
s.hourly = time.NewTicker(iceRetryInterval)
|
||||
}
|
||||
|
||||
func (s *iceRetryState) hourlyC() <-chan time.Time {
|
||||
if s.hourly == nil {
|
||||
return nil
|
||||
}
|
||||
return s.hourly.C
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func newTestRetryState() *iceRetryState {
|
||||
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
|
||||
}
|
||||
|
||||
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 0; i < maxICERetries; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
if s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if s.hourlyC() == nil {
|
||||
t.Fatalf("hourlyC returned nil after enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
|
||||
// Subsequent calls also return true — we keep retrying on each hourly tick.
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
s.enterHourlyMode()
|
||||
|
||||
s.reset()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel after reset")
|
||||
}
|
||||
if s.retries != 0 {
|
||||
t.Fatalf("retries = %d after reset, want 0", s.retries)
|
||||
}
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.reset()
|
||||
s.reset() // second call must not panic or re-stop a nil ticker
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC non-nil after double reset")
|
||||
}
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
|
||||
return srw
|
||||
}
|
||||
|
||||
func (w *SRWatcher) Start(disableICEMonitor bool) {
|
||||
func (w *SRWatcher) Start() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
@@ -50,10 +50,8 @@ func (w *SRWatcher) Start(disableICEMonitor bool) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
if !disableICEMonitor {
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
}
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -44,10 +43,6 @@ type OfferAnswer struct {
|
||||
SessionID *ICESessionID
|
||||
}
|
||||
|
||||
func (o *OfferAnswer) hasICECredentials() bool {
|
||||
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
@@ -64,10 +59,6 @@ type Handshaker struct {
|
||||
relayListener *AsyncOfferListener
|
||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||
|
||||
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
|
||||
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
|
||||
remoteICESupported atomic.Bool
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||
@@ -75,7 +66,7 @@ type Handshaker struct {
|
||||
}
|
||||
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
||||
h := &Handshaker{
|
||||
return &Handshaker{
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
@@ -85,13 +76,6 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
}
|
||||
// assume remote supports ICE until we learn otherwise from received offers
|
||||
h.remoteICESupported.Store(ice != nil)
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handshaker) RemoteICESupported() bool {
|
||||
return h.remoteICESupported.Load()
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
@@ -106,20 +90,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
if h.iceListener != nil {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
@@ -128,20 +110,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
if h.iceListener != nil {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
@@ -203,18 +183,15 @@ func (h *Handshaker) sendAnswer() error {
|
||||
}
|
||||
|
||||
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer := OfferAnswer{
|
||||
IceCredentials: IceCredentials{uFrag, pwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
}
|
||||
|
||||
if h.ice != nil && h.RemoteICESupported() {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer.IceCredentials = IceCredentials{uFrag, pwd}
|
||||
answer.SessionID = &sid
|
||||
SessionID: &sid,
|
||||
}
|
||||
|
||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
@@ -223,18 +200,3 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
|
||||
return answer
|
||||
}
|
||||
|
||||
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
|
||||
hasICE := offer.hasICECredentials()
|
||||
prev := h.remoteICESupported.Swap(hasICE)
|
||||
if prev != hasICE {
|
||||
if hasICE {
|
||||
h.log.Infof("remote peer started sending ICE credentials")
|
||||
} else {
|
||||
h.log.Infof("remote peer stopped sending ICE credentials")
|
||||
if h.ice != nil {
|
||||
h.ice.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,13 +46,9 @@ func (s *Signaler) Ready() bool {
|
||||
|
||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
||||
var sessionIDBytes []byte
|
||||
if offerAnswer.SessionID != nil {
|
||||
var err error
|
||||
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
msg, err := signal.MarshalCredential(
|
||||
s.wgPrivateKey,
|
||||
|
||||
@@ -89,16 +89,8 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
||||
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
|
||||
}
|
||||
|
||||
reused := false
|
||||
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
||||
if !errors.Is(err, unix.EEXIST) {
|
||||
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
||||
}
|
||||
// macOS installs its own RTF_IFSCOPE defaults for primary service
|
||||
// selection on multi-NIC setups, so a route on this ifindex can
|
||||
// already exist before we try. Binding to it via IP[V6]_BOUND_IF
|
||||
// still produces the scoped lookup we need.
|
||||
reused = true
|
||||
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
||||
}
|
||||
|
||||
af := unix.AF_INET
|
||||
@@ -110,11 +102,7 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
||||
if nexthop.IP.IsValid() {
|
||||
via = nexthop.IP.String()
|
||||
}
|
||||
verb := "installed"
|
||||
if reused {
|
||||
verb = "reused existing"
|
||||
}
|
||||
log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec))
|
||||
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,358 +2,217 @@
|
||||
|
||||
package sleep
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
|
||||
#include <IOKit/pwr_mgt/IOPMLib.h>
|
||||
#include <IOKit/IOMessage.h>
|
||||
#include <CoreFoundation/CoreFoundation.h>
|
||||
|
||||
extern void sleepCallbackBridge();
|
||||
extern void poweredOnCallbackBridge();
|
||||
extern void suspendedCallbackBridge();
|
||||
extern void resumedCallbackBridge();
|
||||
|
||||
|
||||
// C global variables for IOKit state
|
||||
static IONotificationPortRef g_notifyPortRef = NULL;
|
||||
static io_object_t g_notifierObject = 0;
|
||||
static io_object_t g_generalInterestNotifier = 0;
|
||||
static io_connect_t g_rootPort = 0;
|
||||
static CFRunLoopRef g_runLoop = NULL;
|
||||
|
||||
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
|
||||
switch (messageType) {
|
||||
case kIOMessageSystemWillSleep:
|
||||
sleepCallbackBridge();
|
||||
IOAllowPowerChange(g_rootPort, (long)messageArgument);
|
||||
break;
|
||||
case kIOMessageSystemHasPoweredOn:
|
||||
poweredOnCallbackBridge();
|
||||
break;
|
||||
case kIOMessageServiceIsSuspended:
|
||||
suspendedCallbackBridge();
|
||||
break;
|
||||
case kIOMessageServiceIsResumed:
|
||||
resumedCallbackBridge();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void registerNotifications() {
|
||||
g_rootPort = IORegisterForSystemPower(
|
||||
NULL,
|
||||
&g_notifyPortRef,
|
||||
(IOServiceInterestCallback)sleepCallback,
|
||||
&g_notifierObject
|
||||
);
|
||||
|
||||
if (g_rootPort == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
CFRunLoopAddSource(CFRunLoopGetCurrent(),
|
||||
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||
kCFRunLoopCommonModes);
|
||||
|
||||
g_runLoop = CFRunLoopGetCurrent();
|
||||
CFRunLoopRun();
|
||||
}
|
||||
|
||||
static void unregisterNotifications() {
|
||||
CFRunLoopRemoveSource(g_runLoop,
|
||||
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||
kCFRunLoopCommonModes);
|
||||
|
||||
IODeregisterForSystemPower(&g_notifierObject);
|
||||
IOServiceClose(g_rootPort);
|
||||
IONotificationPortDestroy(g_notifyPortRef);
|
||||
CFRunLoopStop(g_runLoop);
|
||||
|
||||
g_notifyPortRef = NULL;
|
||||
g_notifierObject = 0;
|
||||
g_rootPort = 0;
|
||||
g_runLoop = NULL;
|
||||
}
|
||||
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// IOKit message types from IOKit/IOMessage.h.
|
||||
const (
|
||||
kIOMessageCanSystemSleep uintptr = 0xe0000270
|
||||
kIOMessageSystemWillSleep uintptr = 0xe0000280
|
||||
kIOMessageSystemHasPoweredOn uintptr = 0xe0000300
|
||||
)
|
||||
|
||||
var (
|
||||
ioKit iokitFuncs
|
||||
cf cfFuncs
|
||||
cfCommonModes uintptr
|
||||
|
||||
libInitOnce sync.Once
|
||||
libInitErr error
|
||||
|
||||
// callbackThunk is the single C-callable trampoline registered with IOKit.
|
||||
callbackThunk uintptr
|
||||
|
||||
serviceRegistry = make(map[*Detector]struct{})
|
||||
serviceRegistryMu sync.Mutex
|
||||
session *runLoopSession
|
||||
|
||||
// lifecycleMu serializes Register/Deregister so a new registration can't
|
||||
// start a second runloop while a previous teardown is still pending.
|
||||
lifecycleMu sync.Mutex
|
||||
)
|
||||
|
||||
// iokitFuncs holds IOKit symbols resolved once at init.
|
||||
type iokitFuncs struct {
|
||||
IORegisterForSystemPower func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr
|
||||
IODeregisterForSystemPower func(notifier *uintptr) int32
|
||||
IOAllowPowerChange func(kernelPort uintptr, notificationID uintptr) int32
|
||||
IOServiceClose func(connect uintptr) int32
|
||||
IONotificationPortGetRunLoopSource func(port uintptr) uintptr
|
||||
IONotificationPortDestroy func(port uintptr)
|
||||
}
|
||||
|
||||
// cfFuncs holds CoreFoundation symbols resolved once at init.
|
||||
type cfFuncs struct {
|
||||
CFRunLoopGetCurrent func() uintptr
|
||||
CFRunLoopRun func()
|
||||
CFRunLoopStop func(rl uintptr)
|
||||
CFRunLoopAddSource func(rl, source, mode uintptr)
|
||||
CFRunLoopRemoveSource func(rl, source, mode uintptr)
|
||||
}
|
||||
|
||||
// runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil
|
||||
// session means no runloop is active and the next Register must start one.
|
||||
type runLoopSession struct {
|
||||
rl uintptr
|
||||
port uintptr
|
||||
notifier uintptr
|
||||
rp uintptr
|
||||
}
|
||||
|
||||
// detectorSnapshot pins a detector's callback and done channel so dispatch
|
||||
// runs with values valid at snapshot time, even if a concurrent
|
||||
// Deregister/Register rewrites the detector's fields.
|
||||
type detectorSnapshot struct {
|
||||
detector *Detector
|
||||
callback func(event EventType)
|
||||
done <-chan struct{}
|
||||
}
|
||||
|
||||
// Detector delivers sleep and wake events to a registered callback.
|
||||
type Detector struct {
|
||||
callback func(event EventType)
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// Register installs callback for power events. The first registration starts
|
||||
// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit
|
||||
// registration succeeds or fails; subsequent registrations just add to the
|
||||
// dispatch set.
|
||||
func (d *Detector) Register(callback func(event EventType)) error {
|
||||
lifecycleMu.Lock()
|
||||
defer lifecycleMu.Unlock()
|
||||
//export sleepCallbackBridge
|
||||
func sleepCallbackBridge() {
|
||||
log.Info("sleepCallbackBridge event triggered")
|
||||
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
|
||||
for svc := range serviceRegistry {
|
||||
svc.triggerCallback(EventTypeSleep)
|
||||
}
|
||||
}
|
||||
|
||||
//export resumedCallbackBridge
|
||||
func resumedCallbackBridge() {
|
||||
log.Info("resumedCallbackBridge event triggered")
|
||||
}
|
||||
|
||||
//export suspendedCallbackBridge
|
||||
func suspendedCallbackBridge() {
|
||||
log.Info("suspendedCallbackBridge event triggered")
|
||||
}
|
||||
|
||||
//export poweredOnCallbackBridge
|
||||
func poweredOnCallbackBridge() {
|
||||
log.Info("poweredOnCallbackBridge event triggered")
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
|
||||
for svc := range serviceRegistry {
|
||||
svc.triggerCallback(EventTypeWakeUp)
|
||||
}
|
||||
}
|
||||
|
||||
type Detector struct {
|
||||
callback func(event EventType)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewDetector() (*Detector, error) {
|
||||
return &Detector{}, nil
|
||||
}
|
||||
|
||||
func (d *Detector) Register(callback func(event EventType)) error {
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
|
||||
if _, exists := serviceRegistry[d]; exists {
|
||||
serviceRegistryMu.Unlock()
|
||||
return fmt.Errorf("detector service already registered")
|
||||
}
|
||||
d.callback = callback
|
||||
d.done = make(chan struct{})
|
||||
serviceRegistry[d] = struct{}{}
|
||||
needSetup := session == nil
|
||||
serviceRegistryMu.Unlock()
|
||||
|
||||
if !needSetup {
|
||||
d.callback = callback
|
||||
|
||||
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||
|
||||
if len(serviceRegistry) > 0 {
|
||||
serviceRegistry[d] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go runRunLoop(errCh)
|
||||
if err := <-errCh; err != nil {
|
||||
serviceRegistryMu.Lock()
|
||||
delete(serviceRegistry, d)
|
||||
close(d.done)
|
||||
d.done = nil
|
||||
serviceRegistryMu.Unlock()
|
||||
return err
|
||||
}
|
||||
serviceRegistry[d] = struct{}{}
|
||||
|
||||
// CFRunLoop must run on a single fixed OS thread
|
||||
go func() {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
C.registerNotifications()
|
||||
}()
|
||||
|
||||
log.Info("sleep detection service started on macOS")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deregister removes the detector. When the last detector leaves, IOKit
|
||||
// notifications are torn down and the runloop is stopped.
|
||||
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
|
||||
// and the runloop is stopped and cleaned up.
|
||||
func (d *Detector) Deregister() error {
|
||||
lifecycleMu.Lock()
|
||||
defer lifecycleMu.Unlock()
|
||||
|
||||
serviceRegistryMu.Lock()
|
||||
if _, exists := serviceRegistry[d]; !exists {
|
||||
serviceRegistryMu.Unlock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
_, exists := serviceRegistry[d]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
close(d.done)
|
||||
|
||||
// cancel and remove this detector
|
||||
d.cancel()
|
||||
delete(serviceRegistry, d)
|
||||
|
||||
// If other Detectors still exist, leave IOKit running
|
||||
if len(serviceRegistry) > 0 {
|
||||
serviceRegistryMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
sess := session
|
||||
serviceRegistryMu.Unlock()
|
||||
|
||||
log.Info("sleep detection service stopping (deregister)")
|
||||
|
||||
if sess == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if sess.rl != 0 && sess.port != 0 {
|
||||
source := ioKit.IONotificationPortGetRunLoopSource(sess.port)
|
||||
cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes)
|
||||
}
|
||||
if sess.notifier != 0 {
|
||||
n := sess.notifier
|
||||
ioKit.IODeregisterForSystemPower(&n)
|
||||
}
|
||||
|
||||
// Clear session only after IODeregisterForSystemPower returns so any
|
||||
// in-flight powerCallback can still look up session.rp to ack sleep.
|
||||
serviceRegistryMu.Lock()
|
||||
session = nil
|
||||
serviceRegistryMu.Unlock()
|
||||
|
||||
if sess.rp != 0 {
|
||||
ioKit.IOServiceClose(sess.rp)
|
||||
}
|
||||
if sess.port != 0 {
|
||||
ioKit.IONotificationPortDestroy(sess.port)
|
||||
}
|
||||
if sess.rl != 0 {
|
||||
cf.CFRunLoopStop(sess.rl)
|
||||
}
|
||||
// Deregister IOKit notifications, stop runloop, and free resources
|
||||
C.unregisterNotifications()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) {
|
||||
if cb == nil || done == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
func (d *Detector) triggerCallback(event EventType) {
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
timeout := time.NewTimer(500 * time.Millisecond)
|
||||
defer timeout.Stop()
|
||||
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorf("panic in sleep callback: %v", r)
|
||||
}
|
||||
}()
|
||||
cb := d.callback
|
||||
go func(callback func(event EventType)) {
|
||||
log.Info("sleep detection event fired")
|
||||
cb(event)
|
||||
}()
|
||||
callback(event)
|
||||
close(doneChan)
|
||||
}(cb)
|
||||
|
||||
select {
|
||||
case <-doneChan:
|
||||
case <-done:
|
||||
case <-d.ctx.Done():
|
||||
case <-timeout.C:
|
||||
log.Warn("sleep callback timed out")
|
||||
log.Warnf("sleep callback timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector.
|
||||
func NewDetector() (*Detector, error) {
|
||||
if err := initLibs(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Detector{}, nil
|
||||
}
|
||||
|
||||
func initLibs() error {
|
||||
libInitOnce.Do(func() {
|
||||
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libInitErr = fmt.Errorf("dlopen IOKit: %w", err)
|
||||
return
|
||||
}
|
||||
cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower")
|
||||
purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower")
|
||||
purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange")
|
||||
purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose")
|
||||
purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource")
|
||||
purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy")
|
||||
|
||||
purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent")
|
||||
purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun")
|
||||
purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop")
|
||||
purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource")
|
||||
purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource")
|
||||
|
||||
modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes")
|
||||
if err != nil {
|
||||
libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err)
|
||||
return
|
||||
}
|
||||
// Launder the uintptr-to-pointer conversion through a Go variable so
|
||||
// go vet's unsafeptr analyzer doesn't flag a system-library global.
|
||||
cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr))
|
||||
|
||||
// NewCallback slots are a finite, non-reclaimable resource, so register
|
||||
// a single thunk that dispatches to the current Detector set.
|
||||
callbackThunk = purego.NewCallback(powerCallback)
|
||||
})
|
||||
return libInitErr
|
||||
}
|
||||
|
||||
// powerCallback is the IOServiceInterestCallback trampoline, invoked on the
|
||||
// runloop thread. A Go panic crossing the purego boundary has undefined
|
||||
// behavior, so contain it here.
|
||||
func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorf("panic in sleep powerCallback: %v", r)
|
||||
}
|
||||
}()
|
||||
switch messageType {
|
||||
case kIOMessageCanSystemSleep:
|
||||
// Not acknowledging forces a 30s IOKit timeout before idle sleep.
|
||||
allowPowerChange(messageArgument)
|
||||
case kIOMessageSystemWillSleep:
|
||||
dispatchEvent(EventTypeSleep)
|
||||
allowPowerChange(messageArgument)
|
||||
case kIOMessageSystemHasPoweredOn:
|
||||
dispatchEvent(EventTypeWakeUp)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func allowPowerChange(messageArgument uintptr) {
|
||||
serviceRegistryMu.Lock()
|
||||
var port uintptr
|
||||
if session != nil {
|
||||
port = session.rp
|
||||
}
|
||||
serviceRegistryMu.Unlock()
|
||||
if port != 0 {
|
||||
ioKit.IOAllowPowerChange(port, messageArgument)
|
||||
}
|
||||
}
|
||||
|
||||
func dispatchEvent(event EventType) {
|
||||
serviceRegistryMu.Lock()
|
||||
snaps := make([]detectorSnapshot, 0, len(serviceRegistry))
|
||||
for d := range serviceRegistry {
|
||||
snaps = append(snaps, detectorSnapshot{
|
||||
detector: d,
|
||||
callback: d.callback,
|
||||
done: d.done,
|
||||
})
|
||||
}
|
||||
serviceRegistryMu.Unlock()
|
||||
|
||||
for _, s := range snaps {
|
||||
s.detector.triggerCallback(event, s.callback, s.done)
|
||||
}
|
||||
}
|
||||
|
||||
// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup
|
||||
// result is reported on errCh so Register can surface failures synchronously.
|
||||
func runRunLoop(errCh chan<- error) {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
sess, err := setupSession()
|
||||
if err == nil {
|
||||
serviceRegistryMu.Lock()
|
||||
session = sess
|
||||
serviceRegistryMu.Unlock()
|
||||
}
|
||||
errCh <- err
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorf("panic in sleep runloop: %v", r)
|
||||
}
|
||||
}()
|
||||
cf.CFRunLoopRun()
|
||||
}
|
||||
|
||||
// setupSession performs the IOKit registration on the current thread. Panics
|
||||
// are converted to errors so runRunLoop never leaves errCh unsent.
|
||||
func setupSession() (s *runLoopSession, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic during runloop setup: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
var portRef, notifier uintptr
|
||||
rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, ¬ifier)
|
||||
if rp == 0 {
|
||||
return nil, fmt.Errorf("IORegisterForSystemPower returned zero")
|
||||
}
|
||||
|
||||
rl := cf.CFRunLoopGetCurrent()
|
||||
source := ioKit.IONotificationPortGetRunLoopSource(portRef)
|
||||
cf.CFRunLoopAddSource(rl, source, cfCommonModes)
|
||||
|
||||
return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil
|
||||
}
|
||||
|
||||
@@ -13,25 +13,15 @@
|
||||
|
||||
<MajorUpgrade AllowSameVersionUpgrades='yes' DowngradeErrorMessage="A newer version of [ProductName] is already installed. Setup will now exit."/>
|
||||
|
||||
<!-- Autostart: enabled by default, disable with AUTOSTART=0 on the msiexec command line -->
|
||||
<Property Id="AUTOSTART" Value="1" />
|
||||
|
||||
<StandardDirectory Id="ProgramFiles64Folder">
|
||||
<Directory Id="NetbirdInstallDir" Name="Netbird">
|
||||
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
|
||||
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
|
||||
<ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
|
||||
<ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
||||
</Shortcut>
|
||||
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
|
||||
<ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
|
||||
<ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
||||
</Shortcut>
|
||||
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||
</File>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
||||
<File Id="NetbirdToastIcon" Name="netbird.png" Source=".\client\ui\assets\netbird.png" />
|
||||
<?if $(var.ArchSuffix) = "amd64" ?>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
||||
<?endif ?>
|
||||
@@ -56,31 +46,8 @@
|
||||
</Directory>
|
||||
</StandardDirectory>
|
||||
|
||||
<!-- Per-user component: HKCU keypath (auto GUID via "*"), separate from
|
||||
the per-machine NetbirdFiles component to satisfy ICE57. -->
|
||||
<StandardDirectory Id="ProgramMenuFolder">
|
||||
<Component Id="NetbirdAumidRegistry" Guid="*">
|
||||
<RegistryKey Root="HKCU" Key="Software\Classes\AppUserModelId\NetBird" ForceDeleteOnUninstall="yes">
|
||||
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
|
||||
</RegistryKey>
|
||||
</Component>
|
||||
</StandardDirectory>
|
||||
|
||||
<StandardDirectory Id="CommonAppDataFolder">
|
||||
<Directory Id="NetbirdAutoStartDir" Name="Netbird">
|
||||
<Component Id="NetbirdAutoStart" Guid="b199eaca-b0dd-4032-af19-679cfad48eb3" Bitness="always64">
|
||||
<Condition>AUTOSTART = "1"</Condition>
|
||||
<RegistryValue Root="HKLM" Key="Software\Microsoft\Windows\CurrentVersion\Run"
|
||||
Name="Netbird" Value=""[NetbirdInstallDir]netbird-ui.exe""
|
||||
Type="string" KeyPath="yes" />
|
||||
</Component>
|
||||
</Directory>
|
||||
</StandardDirectory>
|
||||
|
||||
<ComponentGroup Id="NetbirdFilesComponent">
|
||||
<ComponentRef Id="NetbirdFiles" />
|
||||
<ComponentRef Id="NetbirdAumidRegistry" />
|
||||
<ComponentRef Id="NetbirdAutoStart" />
|
||||
</ComponentGroup>
|
||||
|
||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -104,6 +104,8 @@ service DaemonService {
|
||||
// StopCPUProfile stops CPU profiling in the daemon
|
||||
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
|
||||
|
||||
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||
|
||||
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
||||
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
@@ -112,6 +114,20 @@ service DaemonService {
|
||||
|
||||
|
||||
|
||||
message OSLifecycleRequest {
|
||||
// avoid collision with loglevel enum
|
||||
enum CycleType {
|
||||
UNKNOWN = 0;
|
||||
SLEEP = 1;
|
||||
WAKEUP = 2;
|
||||
}
|
||||
|
||||
CycleType type = 1;
|
||||
}
|
||||
|
||||
message OSLifecycleResponse {}
|
||||
|
||||
|
||||
message LoginRequest {
|
||||
// setupKey netbird setup key.
|
||||
string setupKey = 1;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -120,7 +120,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
}
|
||||
agent := &serverAgent{s}
|
||||
s.sleepHandler = sleephandler.New(agent)
|
||||
s.startSleepDetector()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -335,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -2,18 +2,13 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/sleep"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const envDisableSleepDetector = "NB_DISABLE_SLEEP_DETECTOR"
|
||||
|
||||
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
|
||||
type serverAgent struct {
|
||||
s *Server
|
||||
@@ -33,61 +28,19 @@ func (a *serverAgent) Status() (internal.StatusType, error) {
|
||||
return internal.CtxGetState(a.s.rootCtx).Status()
|
||||
}
|
||||
|
||||
// startSleepDetector starts the OS sleep/wake detector and forwards events to
|
||||
// the sleep handler. On platforms without a supported detector the attempt
|
||||
// logs a warning and returns. Setting NB_DISABLE_SLEEP_DETECTOR=true skips
|
||||
// registration entirely.
|
||||
func (s *Server) startSleepDetector() {
|
||||
if sleepDetectorDisabled() {
|
||||
log.Info("sleep detection disabled via " + envDisableSleepDetector)
|
||||
return
|
||||
}
|
||||
|
||||
svc, err := sleep.New()
|
||||
if err != nil {
|
||||
log.Warnf("failed to initialize sleep detection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = svc.Register(func(event sleep.EventType) {
|
||||
switch event {
|
||||
case sleep.EventTypeSleep:
|
||||
log.Info("handling sleep event")
|
||||
if err := s.sleepHandler.HandleSleep(s.rootCtx); err != nil {
|
||||
log.Errorf("failed to handle sleep event: %v", err)
|
||||
}
|
||||
case sleep.EventTypeWakeUp:
|
||||
log.Info("handling wakeup event")
|
||||
if err := s.sleepHandler.HandleWakeUp(s.rootCtx); err != nil {
|
||||
log.Errorf("failed to handle wakeup event: %v", err)
|
||||
}
|
||||
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||
switch req.GetType() {
|
||||
case proto.OSLifecycleRequest_WAKEUP:
|
||||
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed to register sleep detector: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("sleep detection service initialized")
|
||||
|
||||
go func() {
|
||||
<-s.rootCtx.Done()
|
||||
log.Info("stopping sleep event listener")
|
||||
if err := svc.Deregister(); err != nil {
|
||||
log.Errorf("failed to deregister sleep detector: %v", err)
|
||||
case proto.OSLifecycleRequest_SLEEP:
|
||||
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func sleepDetectorDisabled() bool {
|
||||
val := os.Getenv(envDisableSleepDetector)
|
||||
if val == "" {
|
||||
return false
|
||||
default:
|
||||
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||
}
|
||||
disabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s=%q: %v", envDisableSleepDetector, val, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -38,10 +38,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/sleep"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||
"github.com/netbirdio/netbird/client/ui/event"
|
||||
"github.com/netbirdio/netbird/client/ui/notifier"
|
||||
"github.com/netbirdio/netbird/client/ui/process"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -260,7 +260,6 @@ type serviceClient struct {
|
||||
|
||||
// application with main windows.
|
||||
app fyne.App
|
||||
notifier notifier.Notifier
|
||||
wSettings fyne.Window
|
||||
showAdvancedSettings bool
|
||||
sendNotification bool
|
||||
@@ -365,7 +364,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
cancel: cancel,
|
||||
addr: args.addr,
|
||||
app: args.app,
|
||||
notifier: notifier.New(args.app),
|
||||
logFile: args.logFile,
|
||||
sendNotification: false,
|
||||
|
||||
@@ -894,7 +892,7 @@ func (s *serviceClient) updateStatus() error {
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
if s.connected {
|
||||
s.notifier.Send("Error", "Connection to service lost")
|
||||
s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost"))
|
||||
}
|
||||
s.setDisconnectedStatus()
|
||||
return err
|
||||
@@ -1111,7 +1109,7 @@ func (s *serviceClient) onTrayReady() {
|
||||
}
|
||||
}()
|
||||
|
||||
s.eventManager = event.NewManager(s.notifier, s.addr)
|
||||
s.eventManager = event.NewManager(s.app, s.addr)
|
||||
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
||||
if event.Category == proto.SystemEvent_SYSTEM {
|
||||
@@ -1148,6 +1146,9 @@ func (s *serviceClient) onTrayReady() {
|
||||
|
||||
go s.eventManager.Start(s.ctx)
|
||||
go s.eventHandler.listen(s.ctx)
|
||||
|
||||
// Start sleep detection listener
|
||||
go s.startSleepListener()
|
||||
}
|
||||
|
||||
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
|
||||
@@ -1208,6 +1209,62 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
|
||||
return s.conn, nil
|
||||
}
|
||||
|
||||
// startSleepListener initializes the sleep detection service and listens for sleep events
|
||||
func (s *serviceClient) startSleepListener() {
|
||||
sleepService, err := sleep.New()
|
||||
if err != nil {
|
||||
log.Warnf("%v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := sleepService.Register(s.handleSleepEvents); err != nil {
|
||||
log.Errorf("failed to start sleep detection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("sleep detection service initialized")
|
||||
|
||||
// Cleanup on context cancellation
|
||||
go func() {
|
||||
<-s.ctx.Done()
|
||||
log.Info("stopping sleep event listener")
|
||||
if err := sleepService.Deregister(); err != nil {
|
||||
log.Errorf("failed to deregister sleep detection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleSleepEvents sends a sleep notification to the daemon via gRPC
|
||||
func (s *serviceClient) handleSleepEvents(event sleep.EventType) {
|
||||
conn, err := s.getSrvClient(0)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get daemon client for sleep notification: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
req := &proto.OSLifecycleRequest{}
|
||||
|
||||
switch event {
|
||||
case sleep.EventTypeWakeUp:
|
||||
log.Infof("handle wakeup event: %v", event)
|
||||
req.Type = proto.OSLifecycleRequest_WAKEUP
|
||||
case sleep.EventTypeSleep:
|
||||
log.Infof("handle sleep event: %v", event)
|
||||
req.Type = proto.OSLifecycleRequest_SLEEP
|
||||
default:
|
||||
log.Infof("unknown event: %v", event)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.NotifyOSLifecycle(s.ctx, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to notify daemon about os lifecycle notification: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("successfully notified daemon about os lifecycle")
|
||||
}
|
||||
|
||||
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
||||
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
||||
if s.mSettings != nil {
|
||||
@@ -1491,7 +1548,7 @@ func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) {
|
||||
|
||||
if enforced && s.lastNotifiedVersion != newVersion {
|
||||
s.lastNotifiedVersion = newVersion
|
||||
s.notifier.Send("Update available", "A new version "+newVersion+" is ready to install")
|
||||
s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install"))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
@@ -17,17 +18,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||
)
|
||||
|
||||
// Notifier sends desktop notifications. Defined here so the event package
|
||||
// does not depend on fyne or the platform-specific notifier implementation.
|
||||
type Notifier interface {
|
||||
Send(title, body string)
|
||||
}
|
||||
|
||||
type Handler func(*proto.SystemEvent)
|
||||
|
||||
type Manager struct {
|
||||
notifier Notifier
|
||||
addr string
|
||||
app fyne.App
|
||||
addr string
|
||||
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
@@ -36,10 +31,10 @@ type Manager struct {
|
||||
handlers []Handler
|
||||
}
|
||||
|
||||
func NewManager(notifier Notifier, addr string) *Manager {
|
||||
func NewManager(app fyne.App, addr string) *Manager {
|
||||
return &Manager{
|
||||
notifier: notifier,
|
||||
addr: addr,
|
||||
app: app,
|
||||
addr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,7 +114,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
|
||||
if id != "" {
|
||||
body += fmt.Sprintf(" ID: %s", id)
|
||||
}
|
||||
e.notifier.Send(title, body)
|
||||
e.app.SendNotification(fyne.NewNotification(title, body))
|
||||
}
|
||||
|
||||
for _, handler := range handlers {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/systray"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -86,7 +87,7 @@ func (h *eventHandler) handleConnectClick() {
|
||||
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
||||
log.Debugf("connect operation cancelled by user")
|
||||
} else {
|
||||
h.client.notifier.Send("Error", "Failed to connect")
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
|
||||
log.Errorf("connect failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -111,7 +112,7 @@ func (h *eventHandler) handleDisconnectClick() {
|
||||
if err := h.client.menuDownClick(); err != nil {
|
||||
st, ok := status.FromError(err)
|
||||
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
|
||||
h.client.notifier.Send("Error", "Failed to disconnect")
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
|
||||
log.Errorf("disconnect failed: %v", err)
|
||||
} else {
|
||||
log.Debugf("disconnect cancelled or already disconnecting")
|
||||
@@ -129,7 +130,7 @@ func (h *eventHandler) handleAllowSSHClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update SSH settings")
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update SSH settings"))
|
||||
}
|
||||
|
||||
}
|
||||
@@ -139,7 +140,7 @@ func (h *eventHandler) handleAutoConnectClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update auto-connect settings")
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update auto-connect settings"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,7 +149,7 @@ func (h *eventHandler) handleRosenpassClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update Rosenpass settings")
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update Rosenpass settings"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,7 +158,7 @@ func (h *eventHandler) handleLazyConnectionClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update lazy connection settings")
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update lazy connection settings"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,7 +167,7 @@ func (h *eventHandler) handleBlockInboundClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update block inbound settings")
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update block inbound settings"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,7 +176,7 @@ func (h *eventHandler) handleNotificationsClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update notifications settings")
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update notifications settings"))
|
||||
} else if h.client.eventManager != nil {
|
||||
h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked())
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
// Package notifier sends desktop notifications. On Windows it uses the WinRT
|
||||
// COM API directly via go-toast/v2 to avoid the PowerShell window flash that
|
||||
// fyne's default implementation produces. On other platforms it delegates to
|
||||
// fyne.
|
||||
package notifier
|
||||
|
||||
import "fyne.io/fyne/v2"
|
||||
|
||||
// Notifier sends desktop notifications.
|
||||
type Notifier interface {
|
||||
Send(title, body string)
|
||||
}
|
||||
|
||||
// New returns a platform-specific Notifier. The fyne app is used as the
|
||||
// fallback notifier on platforms where no native implementation is wired up,
|
||||
// and on Windows when the COM path fails to initialize.
|
||||
func New(app fyne.App) Notifier {
|
||||
return newNotifier(app)
|
||||
}
|
||||
|
||||
type fyneNotifier struct {
|
||||
app fyne.App
|
||||
}
|
||||
|
||||
func (f *fyneNotifier) Send(title, body string) {
|
||||
f.app.SendNotification(fyne.NewNotification(title, body))
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package notifier
|
||||
|
||||
import "fyne.io/fyne/v2"
|
||||
|
||||
func newNotifier(app fyne.App) Notifier {
|
||||
return &fyneNotifier{app: app}
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
toast "git.sr.ht/~jackmordaunt/go-toast/v2"
|
||||
"git.sr.ht/~jackmordaunt/go-toast/v2/wintoast"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// appID is the AppUserModelID shown in the Windows Action Center. It
|
||||
// must match the System.AppUserModel.ID property set on the Start Menu
|
||||
// shortcut by the MSI (see client/netbird.wxs); otherwise Windows
|
||||
// groups toasts under a separate, unbranded entry.
|
||||
appID = "NetBird"
|
||||
|
||||
// appGUID identifies the COM activation callback class. Generated once
|
||||
// for NetBird; do not change without coordinating an installer bump,
|
||||
// since old registry entries pointing at the previous GUID would orphan.
|
||||
appGUID = "{0E1B4DE7-E148-432B-9814-544F941826EC}"
|
||||
)
|
||||
|
||||
type comNotifier struct {
|
||||
fallback *fyneNotifier
|
||||
ready bool
|
||||
iconPath string
|
||||
}
|
||||
|
||||
var (
|
||||
initOnce sync.Once
|
||||
initErr error
|
||||
)
|
||||
|
||||
func newNotifier(app fyne.App) Notifier {
|
||||
n := &comNotifier{
|
||||
fallback: &fyneNotifier{app: app},
|
||||
iconPath: resolveIcon(),
|
||||
}
|
||||
initOnce.Do(func() {
|
||||
initErr = wintoast.SetAppData(wintoast.AppData{
|
||||
AppID: appID,
|
||||
GUID: appGUID,
|
||||
IconPath: n.iconPath,
|
||||
})
|
||||
})
|
||||
if initErr != nil {
|
||||
log.Warnf("toast: register app data failed, falling back to fyne notifications: %v", initErr)
|
||||
return n.fallback
|
||||
}
|
||||
n.ready = true
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *comNotifier) Send(title, body string) {
|
||||
if !n.ready {
|
||||
n.fallback.Send(title, body)
|
||||
return
|
||||
}
|
||||
notification := toast.Notification{
|
||||
AppID: appID,
|
||||
Title: title,
|
||||
Body: body,
|
||||
Icon: n.iconPath,
|
||||
}
|
||||
if err := notification.Push(); err != nil {
|
||||
log.Warnf("toast: push failed, using fyne fallback: %v", err)
|
||||
n.fallback.Send(title, body)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveIcon returns an absolute path to the toast icon, or an empty string
|
||||
// when no icon can be located. Windows requires a PNG/JPG for the
|
||||
// AppUserModelId IconUri registry value; .ico is silently ignored.
|
||||
func resolveIcon() string {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
candidate := filepath.Join(filepath.Dir(exe), "netbird.png")
|
||||
if _, err := os.Stat(candidate); err == nil {
|
||||
return candidate
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -548,7 +548,7 @@ func (p *profileMenu) refresh() {
|
||||
if err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
// show notification dialog
|
||||
p.serviceClient.notifier.Send("Error", "Failed to switch profile")
|
||||
p.app.SendNotification(fyne.NewNotification("Error", "Failed to switch profile"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -628,9 +628,9 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
if err := p.eventHandler.logout(p.ctx); err != nil {
|
||||
log.Errorf("logout failed: %v", err)
|
||||
p.serviceClient.notifier.Send("Error", "Failed to deregister")
|
||||
p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
|
||||
} else {
|
||||
p.serviceClient.notifier.Send("Success", "Deregistered successfully")
|
||||
p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,8 +119,6 @@ server:
|
||||
|
||||
# Reverse proxy settings (optional)
|
||||
# reverseProxy:
|
||||
# trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"])
|
||||
# trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies)
|
||||
# trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"])
|
||||
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
|
||||
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.
|
||||
# trustedHTTPProxies: []
|
||||
# trustedHTTPProxiesCount: 0
|
||||
# trustedPeers: []
|
||||
|
||||
@@ -457,18 +457,6 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
|
||||
// which is when Receive has actually returned. Without this, the Receive goroutine
|
||||
// can outlive the test and call t.Logf after teardown, panicking.
|
||||
receiveDone := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
select {
|
||||
case <-receiveDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Receive goroutine did not exit after Close")
|
||||
}
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
@@ -480,7 +468,6 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
receivedAfterReconnect := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(receiveDone)
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||
return nil
|
||||
|
||||
5
go.mod
5
go.mod
@@ -30,7 +30,6 @@ require (
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.7.0
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.38.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.6
|
||||
@@ -47,7 +46,6 @@ require (
|
||||
github.com/crowdsecurity/go-cs-bouncer v0.0.21
|
||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||
github.com/dexidp/dex/api/v2 v2.4.0
|
||||
github.com/ebitengine/purego v0.8.4
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||
@@ -180,6 +178,7 @@ require (
|
||||
github.com/docker/docker v28.0.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/ebitengine/purego v0.8.4 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fredbi/uri v1.1.1 // indirect
|
||||
github.com/fyne-io/gl-js v0.2.0 // indirect
|
||||
@@ -324,5 +323,3 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184
|
||||
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
|
||||
|
||||
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
|
||||
|
||||
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0
|
||||
|
||||
6
go.sum
6
go.sum
@@ -15,8 +15,6 @@ fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ=
|
||||
fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE=
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4=
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 h1:N3IGoHHp9pb6mj1cbXbuaSXV/UMKwmbKLf53nQmtqMA=
|
||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3/go.mod h1:QtOLZGz8olr4qH2vWK0QH0w0O4T9fEIjMuWpKUsH7nc=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
|
||||
@@ -402,6 +400,8 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
|
||||
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
|
||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||
@@ -449,8 +449,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
|
||||
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
|
||||
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
|
||||
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
|
||||
@@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) {
|
||||
// are stored with types that Dex can open.
|
||||
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
|
||||
switch connType {
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||
return "oidc", applyOIDCDefaults(connType, config)
|
||||
default:
|
||||
return connType, config
|
||||
@@ -218,8 +218,6 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin
|
||||
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
|
||||
case "okta", "pocketid":
|
||||
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "adfs":
|
||||
augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
|
||||
}
|
||||
|
||||
return augmented
|
||||
|
||||
@@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto
|
||||
var err error
|
||||
|
||||
switch cfg.Type {
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||
dexType = "oidc"
|
||||
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
|
||||
case "google":
|
||||
@@ -220,8 +220,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "adfs":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
|
||||
}
|
||||
return encodeConnectorConfig(oidcConfig)
|
||||
}
|
||||
@@ -285,7 +283,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa
|
||||
// inferOIDCProviderType infers the specific OIDC provider from connector ID
|
||||
func inferOIDCProviderType(connectorID string) string {
|
||||
connectorIDLower := strings.ToLower(connectorID)
|
||||
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} {
|
||||
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
|
||||
if strings.Contains(connectorIDLower, provider) {
|
||||
return provider
|
||||
}
|
||||
|
||||
@@ -231,20 +231,7 @@ get_upstream_host() {
|
||||
|
||||
wait_management_proxy() {
|
||||
local proxy_container="${1:-traefik}"
|
||||
local use_docker_logs=false
|
||||
set +e
|
||||
|
||||
if [[ "$proxy_container" == "detect-traefik" ]]; then
|
||||
proxy_container=$(docker ps --format "{{.ID}}\t{{.Image}}\t{{.Ports}}" \
|
||||
| awk -F'\t' '$2 ~ /traefik/ && $3 ~ /:(80|443)->/ {print $1; exit}')
|
||||
|
||||
if [[ -z "$proxy_container" ]]; then
|
||||
echo "Warning: could not auto-detect Traefik container, log output will be skipped on timeout." > /dev/stderr
|
||||
else
|
||||
use_docker_logs=true
|
||||
fi
|
||||
fi
|
||||
|
||||
echo -n "Waiting for NetBird server to become ready"
|
||||
counter=1
|
||||
while true; do
|
||||
@@ -255,13 +242,7 @@ wait_management_proxy() {
|
||||
if [[ $counter -eq 60 ]]; then
|
||||
echo ""
|
||||
echo "Taking too long. Checking logs..."
|
||||
if [[ -n "$proxy_container" ]]; then
|
||||
if [[ "$use_docker_logs" == "true" ]]; then
|
||||
docker logs --tail=20 "$proxy_container"
|
||||
else
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container"
|
||||
fi
|
||||
fi
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container"
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 netbird-server
|
||||
fi
|
||||
echo -n " ."
|
||||
@@ -491,7 +472,7 @@ start_services_and_show_instructions() {
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
||||
echo "Registering CrowdSec bouncer..."
|
||||
local cs_retries=0
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
|
||||
cs_retries=$((cs_retries + 1))
|
||||
if [[ $cs_retries -ge 30 ]]; then
|
||||
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
|
||||
@@ -537,7 +518,7 @@ start_services_and_show_instructions() {
|
||||
$DOCKER_COMPOSE_COMMAND up -d
|
||||
|
||||
sleep 3
|
||||
wait_management_proxy detect-traefik
|
||||
wait_management_direct
|
||||
|
||||
echo -e "$MSG_DONE"
|
||||
print_post_setup_instructions
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -15,9 +16,11 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -55,6 +58,13 @@ type Controller struct {
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
holder *types.Holder
|
||||
|
||||
expNewNetworkMap bool
|
||||
expNewNetworkMapAIDs map[string]struct{}
|
||||
|
||||
compactedNetworkMap bool
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
@@ -71,6 +81,29 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||
}
|
||||
|
||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
|
||||
newNetworkMapBuilder = false
|
||||
}
|
||||
|
||||
compactedNetworkMap := true
|
||||
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
|
||||
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
|
||||
if err != nil && len(compactedEnv) > 0 {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
|
||||
}
|
||||
if err == nil && !parsedCompactedNmap {
|
||||
log.WithContext(ctx).Info("disabling compacted mode")
|
||||
compactedNetworkMap = false
|
||||
}
|
||||
|
||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||
expIDs := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
expIDs[id] = struct{}{}
|
||||
}
|
||||
|
||||
return &Controller{
|
||||
repo: newRepository(store),
|
||||
metrics: nMetrics,
|
||||
@@ -84,6 +117,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
|
||||
holder: types.NewHolder(),
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
expNewNetworkMapAIDs: expIDs,
|
||||
|
||||
compactedNetworkMap: compactedNetworkMap,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,9 +153,17 @@ func (c *Controller) CountStreams() int {
|
||||
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
@@ -150,6 +197,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
@@ -192,7 +243,16 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
switch {
|
||||
case c.experimentalNetworkMap(accountID):
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
case c.compactedNetworkMap:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
default:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
@@ -258,6 +318,10 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return fmt.Errorf("recalculate network map cache: %v", err)
|
||||
}
|
||||
|
||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -307,7 +371,16 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
switch {
|
||||
case c.experimentalNetworkMap(accountId):
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
case c.compactedNetworkMap:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
default:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -378,9 +451,17 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return peer, emptyMap, nil, 0, nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
@@ -412,10 +493,20 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
if c.compactedNetworkMap {
|
||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -427,6 +518,108 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
c.enrichAccountFromHolder(account)
|
||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||
}
|
||||
|
||||
func (c *Controller) getPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||
if account == nil {
|
||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||
return &types.NetworkMap{
|
||||
Network: &types.Network{},
|
||||
}
|
||||
}
|
||||
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
||||
c.enrichAccountFromHolder(account)
|
||||
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
c.enrichAccountFromHolder(account)
|
||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||
account := c.getAccountFromHolder(accountId)
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
account.UpdatePeerInNetworkMapCache(peer)
|
||||
}
|
||||
|
||||
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
account.RecalculateNetworkMapCache(validatedPeers)
|
||||
c.updateAccountInHolder(account)
|
||||
}
|
||||
|
||||
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||
return err
|
||||
}
|
||||
c.recalculateNetworkMapCache(account, validatedPeers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) experimentalNetworkMap(accountId string) bool {
|
||||
_, ok := c.expNewNetworkMapAIDs[accountId]
|
||||
return c.expNewNetworkMap || ok
|
||||
}
|
||||
|
||||
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
|
||||
a := c.holder.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
c.holder.AddAccount(account)
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache = a.NetworkMapCache
|
||||
if account.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
||||
return c.holder.GetAccount(accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
|
||||
a := c.holder.GetAccount(accountID)
|
||||
if a != nil {
|
||||
return a
|
||||
}
|
||||
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return account
|
||||
}
|
||||
|
||||
func (c *Controller) updateAccountInHolder(account *types.Account) {
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
||||
if settings == nil {
|
||||
@@ -563,7 +756,16 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get peers by ids: %w", err)
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
c.UpdatePeerInNetworkMapCache(accountID, peer)
|
||||
}
|
||||
|
||||
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
}
|
||||
@@ -573,6 +775,14 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
|
||||
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
|
||||
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
|
||||
}
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -607,6 +817,19 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||
continue
|
||||
}
|
||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
@@ -649,11 +872,21 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||
} else {
|
||||
account.InjectProxyPolicies(ctx)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
if c.compactedNetworkMap {
|
||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
|
||||
@@ -12,6 +12,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||
|
||||
DnsForwarderPort = nbdns.ForwarderServerPort
|
||||
OldForwarderPort = nbdns.ForwarderClientPort
|
||||
DnsForwarderPortMinVersion = "v0.59.0"
|
||||
|
||||
@@ -30,7 +30,6 @@ import (
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -110,7 +109,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -118,15 +117,6 @@ func (s *BaseServer) APIHandler() http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
return Create(s, func() *middleware.APIRateLimiter {
|
||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||
limiter := middleware.NewAPIRateLimiter(cfg)
|
||||
limiter.SetEnabled(enabled)
|
||||
return limiter
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
return Create(s, func() *grpc.Server {
|
||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||
@@ -173,7 +163,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), s.SessionStore())
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
|
||||
@@ -67,12 +66,6 @@ func (s *BaseServer) SecretsManager() grpc.SecretsManager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) SessionStore() *auth.SessionStore {
|
||||
return Create(s, func() *auth.SessionStore {
|
||||
return auth.NewSessionStore(s.CacheStore())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AuthManager() auth.Manager {
|
||||
audiences := s.Config.GetAuthAudiences()
|
||||
audience := s.Config.HttpConfig.AuthAudience
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
jwtv5 "github.com/golang-jwt/jwt/v5"
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
@@ -68,7 +67,6 @@ type Server struct {
|
||||
appMetrics telemetry.AppMetrics
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
sessionStore *auth.SessionStore
|
||||
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
@@ -100,7 +98,6 @@ func NewServer(
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
networkMapController network_map.Controller,
|
||||
oAuthConfigProvider idp.OAuthConfigProvider,
|
||||
sessionStore *auth.SessionStore,
|
||||
) (*Server, error) {
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
@@ -143,7 +140,6 @@ func NewServer(
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
networkMapController: networkMapController,
|
||||
oAuthConfigProvider: oAuthConfigProvider,
|
||||
sessionStore: sessionStore,
|
||||
|
||||
loginFilter: newLoginFilter(),
|
||||
|
||||
@@ -539,7 +535,7 @@ func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID st
|
||||
log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
|
||||
}
|
||||
|
||||
func (s *Server) validateToken(ctx context.Context, peerKey, jwtToken string) (string, error) {
|
||||
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
if s.authManager == nil {
|
||||
return "", status.Errorf(codes.Internal, "missing auth manager")
|
||||
}
|
||||
@@ -549,10 +545,6 @@ func (s *Server) validateToken(ctx context.Context, peerKey, jwtToken string) (s
|
||||
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
|
||||
if err := s.claimLoginToken(ctx, peerKey, jwtToken, token); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth)
|
||||
if err != nil {
|
||||
@@ -836,31 +828,6 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func (s *Server) claimLoginToken(ctx context.Context, peerKey, jwtToken string, token *jwtv5.Token) error {
|
||||
if s.sessionStore == nil || token == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
exp, err := token.Claims.GetExpirationTime()
|
||||
if err != nil || exp == nil {
|
||||
log.WithContext(ctx).Warnf("JWT has no usable exp claim for peer %s", peerKey)
|
||||
return status.Error(codes.Unauthenticated, "jwt token has no expiration")
|
||||
}
|
||||
|
||||
err = s.sessionStore.RegisterToken(ctx, jwtToken, exp.Time)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, auth.ErrTokenAlreadyUsed) || errors.Is(err, auth.ErrTokenExpired) {
|
||||
log.WithContext(ctx).Warnf("%v for peer %s", err, peerKey)
|
||||
return status.Error(codes.Unauthenticated, err.Error())
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Warnf("failed to claim JWT for peer %s: %v", peerKey, err)
|
||||
return status.Error(codes.Unavailable, "failed to claim jwt token")
|
||||
}
|
||||
|
||||
// processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if
|
||||
// the token is valid.
|
||||
//
|
||||
@@ -871,7 +838,7 @@ func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginReque
|
||||
if loginReq.GetJwtToken() != "" {
|
||||
var err error
|
||||
for i := 0; i < 3; i++ {
|
||||
userID, err = s.validateToken(ctx, peerKey.String(), loginReq.GetJwtToken())
|
||||
userID, err = s.validateToken(ctx, loginReq.GetJwtToken())
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -408,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@@ -1171,6 +1171,11 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
@@ -1226,6 +1231,11 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
@@ -1264,6 +1274,11 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
@@ -1317,6 +1332,11 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
@@ -1377,6 +1397,11 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
@@ -1608,6 +1633,75 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
}
|
||||
|
||||
func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
||||
},
|
||||
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
|
||||
Routes: map[route.ID]*route.Route{
|
||||
"route-1": {
|
||||
ID: "route-1",
|
||||
Network: prefix,
|
||||
NetID: "network-1",
|
||||
Description: "network-1",
|
||||
Peer: "peer-1",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
"route-2": {
|
||||
ID: "route-2",
|
||||
Network: prefix2,
|
||||
NetID: "network-2",
|
||||
Description: "network-2",
|
||||
Peer: "peer-2",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
"route-3": {
|
||||
ID: "route-3",
|
||||
Network: prefix,
|
||||
NetID: "network-1",
|
||||
Description: "network-1",
|
||||
Peer: "peer-2",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
for _, r := range routes {
|
||||
routeIDs[r.ID] = struct{}{}
|
||||
}
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
|
||||
|
||||
assert.Len(t, emptyRoutes, 0)
|
||||
}
|
||||
|
||||
func TestAccount_Copy(t *testing.T) {
|
||||
account := &types.Account{
|
||||
Id: "account1",
|
||||
@@ -1730,7 +1824,9 @@ func TestAccount_Copy(t *testing.T) {
|
||||
AccountID: "account1",
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
}
|
||||
account.InitOnce()
|
||||
err := hasNilField(account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -2215,29 +2311,6 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
// Verify the already-expired peer is excluded at the store level
|
||||
peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, peer := range peers {
|
||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query")
|
||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired")
|
||||
}
|
||||
|
||||
// Only the non-expired peer with expiration enabled should be returned
|
||||
require.Len(t, peers, 1)
|
||||
assert.Equal(t, "notexpired01", peers[0].ID)
|
||||
}
|
||||
|
||||
func TestAccount_GetInactivePeers(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
@@ -3157,13 +3230,6 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
return manager, updateManager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
|
||||
// wrappers wait for an expected update message. Sized for slow CI runners
|
||||
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
|
||||
// seconds. Only runs down on failure; passing tests return immediately
|
||||
// when the channel delivers.
|
||||
const peerUpdateTimeout = 5 * time.Second
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
@@ -3182,7 +3248,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/cache"
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
)
|
||||
|
||||
const (
|
||||
usedTokenKeyPrefix = "jwt-used:"
|
||||
usedTokenMarker = "1"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenAlreadyUsed = errors.New("JWT already used")
|
||||
ErrTokenExpired = errors.New("JWT expired")
|
||||
)
|
||||
|
||||
type SessionStore struct {
|
||||
cache *cache.Cache[string]
|
||||
}
|
||||
|
||||
func NewSessionStore(cacheStore store.StoreInterface) *SessionStore {
|
||||
return &SessionStore{cache: cache.New[string](cacheStore)}
|
||||
}
|
||||
|
||||
// RegisterToken records a JWT until its exp time and rejects reuse.
|
||||
func (s *SessionStore) RegisterToken(ctx context.Context, token string, expiresAt time.Time) error {
|
||||
ttl := time.Until(expiresAt)
|
||||
if ttl <= 0 {
|
||||
return ErrTokenExpired
|
||||
}
|
||||
|
||||
key := usedTokenKeyPrefix + hashToken(token)
|
||||
_, err := s.cache.Get(ctx, key)
|
||||
if err == nil {
|
||||
return ErrTokenAlreadyUsed
|
||||
}
|
||||
|
||||
var notFound *store.NotFound
|
||||
if !errors.As(err, ¬Found) {
|
||||
return fmt.Errorf("failed to lookup used token entry: %w", err)
|
||||
}
|
||||
|
||||
if err := s.cache.Set(ctx, key, usedTokenMarker, store.WithExpiration(ttl)); err != nil {
|
||||
return fmt.Errorf("failed to store used token entry: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func hashToken(token string) string {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
)
|
||||
|
||||
func newTestSessionStore(t *testing.T) *SessionStore {
|
||||
t.Helper()
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), time.Hour, time.Hour, 100)
|
||||
require.NoError(t, err)
|
||||
return NewSessionStore(cacheStore)
|
||||
}
|
||||
|
||||
func TestSessionStore_FirstRegisterSucceeds(t *testing.T) {
|
||||
s := newTestSessionStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
require.NoError(t, s.RegisterToken(ctx, "token", time.Now().Add(time.Hour)))
|
||||
}
|
||||
|
||||
func TestSessionStore_RegisterSameTokenTwiceIsRejected(t *testing.T) {
|
||||
s := newTestSessionStore(t)
|
||||
ctx := context.Background()
|
||||
token := "token"
|
||||
exp := time.Now().Add(time.Hour)
|
||||
|
||||
require.NoError(t, s.RegisterToken(ctx, token, exp))
|
||||
|
||||
err := s.RegisterToken(ctx, token, exp)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrTokenAlreadyUsed)
|
||||
}
|
||||
|
||||
func TestSessionStore_RegisterDifferentTokensAreIndependent(t *testing.T) {
|
||||
s := newTestSessionStore(t)
|
||||
ctx := context.Background()
|
||||
exp := time.Now().Add(time.Hour)
|
||||
|
||||
require.NoError(t, s.RegisterToken(ctx, "tokenA", exp))
|
||||
require.NoError(t, s.RegisterToken(ctx, "tokenB", exp))
|
||||
}
|
||||
|
||||
func TestSessionStore_RegisterWithPastExpiryIsRejected(t *testing.T) {
|
||||
s := newTestSessionStore(t)
|
||||
ctx := context.Background()
|
||||
token := "token"
|
||||
|
||||
err := s.RegisterToken(ctx, token, time.Now().Add(-time.Second))
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrTokenExpired)
|
||||
}
|
||||
|
||||
func TestSessionStore_EntryEvictsAtTTLAndAllowsReRegistration(t *testing.T) {
|
||||
s := newTestSessionStore(t)
|
||||
ctx := context.Background()
|
||||
token := "token"
|
||||
|
||||
require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond)))
|
||||
|
||||
err := s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond))
|
||||
assert.ErrorIs(t, err, ErrTokenAlreadyUsed)
|
||||
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(time.Hour)))
|
||||
}
|
||||
|
||||
func TestHashToken_StableAndDoesNotLeak(t *testing.T) {
|
||||
a := hashToken("tokenA")
|
||||
b := hashToken("tokenB")
|
||||
assert.Equal(t, a, hashToken("tokenA"), "hash must be deterministic")
|
||||
assert.NotEqual(t, a, b, "different tokens must hash differently")
|
||||
assert.Len(t, a, 64, "sha256 hex must be 64 chars")
|
||||
assert.NotContains(t, a, "tokenA", "raw token must not appear in hash")
|
||||
}
|
||||
@@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -5,6 +5,9 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
@@ -63,11 +66,14 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
apiPrefix = "/api"
|
||||
apiPrefix = "/api"
|
||||
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
|
||||
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
|
||||
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -88,10 +94,34 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
|
||||
if rateLimiter == nil {
|
||||
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
|
||||
rateLimiter = middleware.NewAPIRateLimiter(nil)
|
||||
rateLimiter.SetEnabled(false)
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
rpm := 6
|
||||
if v := os.Getenv(rateLimitingRPMKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
|
||||
burst := 500
|
||||
if v := os.Getenv(rateLimitingBurstKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
|
||||
rateLimitingConfig = &middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
@@ -99,7 +129,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
accountManager.GetAccountIDFromUserAuth,
|
||||
accountManager.SyncUserJWTGroups,
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimiter,
|
||||
rateLimitingConfig,
|
||||
appMetrics.GetMeter(),
|
||||
)
|
||||
|
||||
|
||||
@@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
@@ -43,9 +42,14 @@ func NewAuthMiddleware(
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiter *APIRateLimiter,
|
||||
rateLimiterConfig *RateLimiterConfig,
|
||||
meter metric.Meter,
|
||||
) *AuthMiddleware {
|
||||
var rateLimiter *APIRateLimiter
|
||||
if rateLimiterConfig != nil {
|
||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||
}
|
||||
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
var err error
|
||||
@@ -83,14 +87,17 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
||||
request, err := m.checkJWTFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
case "token":
|
||||
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
||||
request, err := m.checkPATFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
if _, ok := status.FromError(err); !ok {
|
||||
@@ -99,7 +106,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
h.ServeHTTP(w, request)
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
@@ -108,19 +115,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||
if err != nil {
|
||||
return err
|
||||
return r, err
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
@@ -136,7 +143,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
return r, err
|
||||
}
|
||||
|
||||
if userAuth.AccountId != accountId {
|
||||
@@ -146,7 +153,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
|
||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
return r, err
|
||||
}
|
||||
|
||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||
@@ -157,41 +164,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||
return err
|
||||
return r, err
|
||||
}
|
||||
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
return nil
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
if m.patUsageTracker != nil {
|
||||
m.patUsageTracker.IncrementUsage(token)
|
||||
}
|
||||
|
||||
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
|
||||
return status.Errorf(status.TooManyRequests, "too many requests")
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
return r, fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
if time.Now().After(pat.GetExpirationDate()) {
|
||||
return fmt.Errorf("token expired")
|
||||
return r, fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
return r, err
|
||||
}
|
||||
|
||||
userAuth := auth.UserAuth{
|
||||
@@ -209,9 +216,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
}
|
||||
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
return nil
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
}
|
||||
|
||||
func isTerraformRequest(r *http.Request) bool {
|
||||
|
||||
@@ -196,8 +196,6 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
@@ -209,7 +207,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -268,7 +266,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -320,7 +318,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -363,7 +361,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -407,7 +405,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -471,7 +469,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -530,7 +528,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -585,7 +583,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -672,8 +670,6 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
@@ -685,7 +681,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,27 +4,14 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
const (
|
||||
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
|
||||
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
|
||||
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
|
||||
|
||||
defaultAPIRPM = 6
|
||||
defaultAPIBurst = 500
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
type RateLimiterConfig struct {
|
||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||
@@ -47,43 +34,6 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
|
||||
rpm := defaultAPIRPM
|
||||
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
if rpm <= 0 {
|
||||
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
|
||||
rpm = defaultAPIRPM
|
||||
}
|
||||
|
||||
burst := defaultAPIBurst
|
||||
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
if burst <= 0 {
|
||||
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
|
||||
burst = defaultAPIBurst
|
||||
}
|
||||
|
||||
return &RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}, os.Getenv(RateLimitingEnabledEnv) == "true"
|
||||
}
|
||||
|
||||
// limiterEntry holds a rate limiter and its last access time
|
||||
type limiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
@@ -96,7 +46,6 @@ type APIRateLimiter struct {
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
enabled atomic.Bool
|
||||
}
|
||||
|
||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||
@@ -110,53 +59,14 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
rl.enabled.Store(true)
|
||||
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
|
||||
rl.enabled.Store(enabled)
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) Enabled() bool {
|
||||
return rl.enabled.Load()
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
|
||||
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
|
||||
return
|
||||
}
|
||||
|
||||
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
|
||||
newBurst := config.Burst
|
||||
|
||||
rl.mu.Lock()
|
||||
rl.config.RequestsPerMinute = config.RequestsPerMinute
|
||||
rl.config.Burst = newBurst
|
||||
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
|
||||
for _, entry := range rl.limiters {
|
||||
snapshot = append(snapshot, entry.limiter)
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
for _, l := range snapshot {
|
||||
l.SetLimit(newRPS)
|
||||
l.SetBurst(newBurst)
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request for the given key (token) is allowed
|
||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
if !rl.enabled.Load() {
|
||||
return true
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Allow()
|
||||
}
|
||||
@@ -164,9 +74,6 @@ func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
// Wait blocks until the rate limiter allows another request for the given key
|
||||
// Returns an error if the context is canceled
|
||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||
if !rl.enabled.Load() {
|
||||
return nil
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Wait(ctx)
|
||||
}
|
||||
@@ -246,10 +153,6 @@ func (rl *APIRateLimiter) Reset(key string) {
|
||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !rl.enabled.Load() {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
clientIP := getClientIP(r)
|
||||
if !rl.Allow(clientIP) {
|
||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -158,172 +156,3 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("key"))
|
||||
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
|
||||
|
||||
rl.SetEnabled(false)
|
||||
assert.False(t, rl.Enabled())
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
|
||||
}
|
||||
|
||||
rl.SetEnabled(true)
|
||||
assert.True(t, rl.Enabled())
|
||||
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("k1"))
|
||||
assert.True(t, rl.Allow("k1"))
|
||||
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
|
||||
|
||||
rl.UpdateConfig(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 10,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
|
||||
// New burst applies to existing keys in place; bucket refills up to new burst over time,
|
||||
// but importantly newly-added keys use the updated config immediately.
|
||||
assert.True(t, rl.Allow("k2"))
|
||||
for i := 0; i < 9; i++ {
|
||||
assert.True(t, rl.Allow("k2"))
|
||||
}
|
||||
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
rl.UpdateConfig(nil) // must not panic or zero the config
|
||||
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"))
|
||||
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
|
||||
rl.Reset("k")
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 600,
|
||||
Burst: 10,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("k%d", id)
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
rl.Allow(key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 200; i++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
rl.UpdateConfig(&RateLimiterConfig{
|
||||
RequestsPerMinute: float64(30 + (i % 90)),
|
||||
Burst: 1 + (i % 20),
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
rl.SetEnabled(i%2 == 0)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRateLimiterConfigFromEnv(t *testing.T) {
|
||||
t.Setenv(RateLimitingEnabledEnv, "true")
|
||||
t.Setenv(RateLimitingRPMEnv, "42")
|
||||
t.Setenv(RateLimitingBurstEnv, "7")
|
||||
|
||||
cfg, enabled := RateLimiterConfigFromEnv()
|
||||
assert.True(t, enabled)
|
||||
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
|
||||
assert.Equal(t, 7, cfg.Burst)
|
||||
|
||||
t.Setenv(RateLimitingEnabledEnv, "false")
|
||||
_, enabled = RateLimiterConfigFromEnv()
|
||||
assert.False(t, enabled)
|
||||
|
||||
t.Setenv(RateLimitingEnabledEnv, "")
|
||||
t.Setenv(RateLimitingRPMEnv, "")
|
||||
t.Setenv(RateLimitingBurstEnv, "")
|
||||
cfg, enabled = RateLimiterConfigFromEnv()
|
||||
assert.False(t, enabled)
|
||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
|
||||
assert.Equal(t, defaultAPIBurst, cfg.Burst)
|
||||
|
||||
t.Setenv(RateLimitingRPMEnv, "0")
|
||||
t.Setenv(RateLimitingBurstEnv, "-5")
|
||||
cfg, _ = RateLimiterConfigFromEnv()
|
||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
|
||||
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -274,7 +274,7 @@ func identityProviderToConnectorConfig(idpConfig *types.IdentityProvider) *dex.C
|
||||
}
|
||||
|
||||
// generateIdentityProviderID generates a unique ID for an identity provider.
|
||||
// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft, adfs),
|
||||
// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft),
|
||||
// the ID is prefixed with the type name. Generic OIDC providers get no prefix.
|
||||
func generateIdentityProviderID(idpType types.IdentityProviderType) string {
|
||||
id := xid.New().String()
|
||||
@@ -296,8 +296,6 @@ func generateIdentityProviderID(idpType types.IdentityProviderType) string {
|
||||
return "authentik-" + id
|
||||
case types.IdentityProviderTypeKeycloak:
|
||||
return "keycloak-" + id
|
||||
case types.IdentityProviderTypeADFS:
|
||||
return "adfs-" + id
|
||||
default:
|
||||
// Generic OIDC - no prefix
|
||||
return id
|
||||
|
||||
@@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
// expired peers come separately.
|
||||
if len(networkMap.GetOfflinePeers()) != 2 {
|
||||
t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer")
|
||||
if len(networkMap.GetOfflinePeers()) != 1 {
|
||||
t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer")
|
||||
}
|
||||
|
||||
expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4="
|
||||
@@ -391,7 +391,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil)
|
||||
if err != nil {
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
@@ -256,7 +256,6 @@ func startServer(
|
||||
server.MockIntegratedValidator{},
|
||||
networkMapController,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating management server: %v", err)
|
||||
|
||||
@@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -33,8 +33,8 @@ import (
|
||||
|
||||
const remoteJobsMinVer = "0.64.0"
|
||||
|
||||
// GetPeers returns peers visible to the user within an account.
|
||||
// Users with "peers:read" see all peers. Otherwise, users see only their own peers, or none if restricted by account settings.
|
||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||
// the current user is not an admin.
|
||||
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
@@ -46,8 +46,14 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// @note if the user has permission to read peers it shows all account peers
|
||||
if allowed {
|
||||
return am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter)
|
||||
return accountPeers, nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -59,7 +65,41 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return []*nbpeer.Peer{}, nil
|
||||
}
|
||||
|
||||
return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
// @note if it does not have permission read peers then only display it's own peers
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
|
||||
for _, peer := range accountPeers {
|
||||
if user.Id != peer.UserID {
|
||||
continue
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
peersMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
return am.getUserAccessiblePeers(ctx, accountID, peersMap, peers)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, accountID string, peersMap map[string]*nbpeer.Peer, peers []*nbpeer.Peer) ([]*nbpeer.Peer, error) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers())
|
||||
for _, p := range aclPeers {
|
||||
peersMap[p.ID] = p
|
||||
}
|
||||
}
|
||||
|
||||
return maps.Values(peersMap), nil
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
@@ -1190,8 +1230,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPeer returns a peer visible to the user within an account.
|
||||
// Users with "peers:read" permission can access any peer. Otherwise, users can access only their own peer.
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
@@ -1216,6 +1255,36 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
return am.checkIfUserOwnsPeer(ctx, accountID, userID, peer)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// it is also possible that user doesn't own the peer but some of his peers have access to it,
|
||||
// this is a valid case, show the peer as well.
|
||||
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers())
|
||||
for _, aclPeer := range aclPeers {
|
||||
if aclPeer.ID == peer.ID {
|
||||
return peer, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peer.ID, accountID)
|
||||
}
|
||||
|
||||
@@ -1336,10 +1405,6 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
for _, peer := range peersWithExpiry {
|
||||
if peer.Status.LoginExpired {
|
||||
continue
|
||||
}
|
||||
|
||||
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
if expired {
|
||||
peers = append(peers, peer)
|
||||
|
||||
@@ -179,6 +179,11 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func testGetNetworkMapGeneral(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
@@ -559,9 +564,25 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
}
|
||||
assert.NotNil(t, peer)
|
||||
|
||||
// the user can NOT see peer2 because it is not owned by them.
|
||||
// Regular users only see peers they directly own.
|
||||
_, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser)
|
||||
// the user can see peer2 because peer1 of the user has access to peer2 due to the All group and the default rule 0 all-to-all access
|
||||
peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, peer)
|
||||
|
||||
// delete the all-to-all policy so that user's peer1 has no access to peer2
|
||||
for _, policy := range account.Policies {
|
||||
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// at this point the user can't see the details of peer2
|
||||
peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) //nolint
|
||||
assert.Error(t, err)
|
||||
|
||||
// admin users can always access all the peers
|
||||
@@ -995,6 +1016,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers(t *testing.T) {
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
@@ -1574,6 +1600,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_LoginPeer(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
@@ -1880,7 +1907,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1902,7 +1929,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1967,7 +1994,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1985,7 +2012,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2031,7 +2058,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2049,7 +2076,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2086,7 +2113,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2104,7 +2131,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
_ "embed"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -47,40 +46,25 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
var isUpdate = policy.ID != ""
|
||||
var updateAccountPeers bool
|
||||
var action = activity.PolicyAdded
|
||||
var unchanged bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
|
||||
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
saveFunc := transaction.CreatePolicy
|
||||
if isUpdate {
|
||||
if policy.Equal(existingPolicy) {
|
||||
logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
|
||||
unchanged = true
|
||||
return nil
|
||||
}
|
||||
|
||||
action = activity.PolicyUpdated
|
||||
saveFunc = transaction.SavePolicy
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = saveFunc(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
@@ -89,10 +73,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if unchanged {
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
@@ -121,7 +101,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -158,55 +138,49 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
// validatePolicy validates the policy and its rules. For updates it returns
|
||||
// the existing policy loaded from the store so callers can avoid a second read.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {
|
||||
var existingPolicy *types.Policy
|
||||
if policy.ID != "" {
|
||||
var err error
|
||||
existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
||||
if isUpdate {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
||||
if policy.ID != "" {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Refactor to support multiple rules per policy
|
||||
@@ -217,7 +191,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -227,12 +201,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
for i, rule := range policy.Rules {
|
||||
@@ -251,7 +225,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||
}
|
||||
|
||||
return existingPolicy, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||
|
||||
@@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
@@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -2,8 +2,10 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1838,6 +1840,11 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
validatedPeers := make(map[string]struct{})
|
||||
for p := range account.Peers {
|
||||
validatedPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
t.Run("check applied policies for the route", func(t *testing.T) {
|
||||
route1 := account.Routes["route1"]
|
||||
policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
|
||||
@@ -1851,6 +1858,116 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
|
||||
assert.Len(t, policies, 0)
|
||||
})
|
||||
|
||||
t.Run("check peer routes firewall rules", func(t *testing.T) {
|
||||
routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 4)
|
||||
|
||||
expectedRoutesFirewallRules := []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.0.0/16",
|
||||
Protocol: "all",
|
||||
Port: 80,
|
||||
RouteID: "route1:peerA",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.0.0/16",
|
||||
Protocol: "all",
|
||||
Port: 320,
|
||||
RouteID: "route1:peerA",
|
||||
},
|
||||
}
|
||||
additionalFirewallRule := []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.10.0/16",
|
||||
Protocol: "tcp",
|
||||
Port: 80,
|
||||
RouteID: "route4:peerA",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.10.0/16",
|
||||
Protocol: "all",
|
||||
RouteID: "route4:peerA",
|
||||
},
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
|
||||
|
||||
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
|
||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 2)
|
||||
for _, rule := range expectedRoutesFirewallRules {
|
||||
rule.RouteID = "route1:peerD"
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
|
||||
|
||||
// peerE is a single routing peer for route 2 and route 3
|
||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 3)
|
||||
|
||||
expectedRoutesFirewallRules = []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
|
||||
Action: "accept",
|
||||
Destination: existingNetwork.String(),
|
||||
Protocol: "tcp",
|
||||
PortRange: types.RulePortRange{Start: 80, End: 350},
|
||||
RouteID: "route2",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{"0.0.0.0/0"},
|
||||
Action: "accept",
|
||||
Destination: "192.0.2.0/32",
|
||||
Protocol: "all",
|
||||
Domains: domain.List{"example.com"},
|
||||
IsDynamic: true,
|
||||
RouteID: "route3",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{"::/0"},
|
||||
Action: "accept",
|
||||
Destination: "192.0.2.0/32",
|
||||
Protocol: "all",
|
||||
Domains: domain.List{"example.com"},
|
||||
IsDynamic: true,
|
||||
RouteID: "route3",
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
|
||||
|
||||
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
|
||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 0)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// orderList is a helper function to sort a list of strings
|
||||
func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule {
|
||||
for _, rule := range ruleList {
|
||||
sort.Strings(rule.SourceRanges)
|
||||
}
|
||||
return ruleList
|
||||
}
|
||||
|
||||
func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
@@ -1953,7 +2070,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
@@ -1990,7 +2107,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2010,7 +2127,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2028,7 +2145,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2068,7 +2185,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2108,7 +2225,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2548,6 +2665,11 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
validatedPeers := make(map[string]struct{})
|
||||
for p := range account.Peers {
|
||||
validatedPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
t.Run("validate applied policies for different network resources", func(t *testing.T) {
|
||||
// Test case: Resource1 is directly applied to the policy (policyResource1)
|
||||
policies := account.GetPoliciesForNetworkResource("resource1")
|
||||
@@ -2571,4 +2693,127 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
|
||||
policies = account.GetPoliciesForNetworkResource("resource6")
|
||||
assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups")
|
||||
})
|
||||
|
||||
t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) {
|
||||
resourcePoliciesMap := account.GetResourcePoliciesMap()
|
||||
resourceRoutersMap := account.GetResourceRoutersMap()
|
||||
_, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap)
|
||||
firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Len(t, sourcePeers, 5)
|
||||
|
||||
expectedFirewallRules := []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.0.0/16",
|
||||
Protocol: "all",
|
||||
Port: 80,
|
||||
RouteID: "resource2:peerA",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.0.0/16",
|
||||
Protocol: "all",
|
||||
Port: 320,
|
||||
RouteID: "resource2:peerA",
|
||||
},
|
||||
}
|
||||
|
||||
additionalFirewallRules := []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.0.2.0/32",
|
||||
Protocol: "tcp",
|
||||
Port: 80,
|
||||
Domains: domain.List{"example.com"},
|
||||
IsDynamic: true,
|
||||
RouteID: "resource4:peerA",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.0.2.0/32",
|
||||
Protocol: "all",
|
||||
Domains: domain.List{"example.com"},
|
||||
IsDynamic: true,
|
||||
RouteID: "resource4:peerA",
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...)))
|
||||
|
||||
// peerD is also the routing peer for resource2
|
||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap)
|
||||
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap)
|
||||
assert.Len(t, firewallRules, 2)
|
||||
for _, rule := range expectedFirewallRules {
|
||||
rule.RouteID = "resource2:peerD"
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
|
||||
assert.Len(t, sourcePeers, 3)
|
||||
|
||||
// peerE is a single routing peer for resource1 and resource3
|
||||
// PeerE should only receive rules for resource1 since resource3 has no applied policy
|
||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap)
|
||||
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap)
|
||||
assert.Len(t, firewallRules, 1)
|
||||
assert.Len(t, sourcePeers, 2)
|
||||
|
||||
expectedFirewallRules = []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
|
||||
Action: "accept",
|
||||
Destination: "10.10.10.0/24",
|
||||
Protocol: "tcp",
|
||||
PortRange: types.RulePortRange{Start: 80, End: 350},
|
||||
RouteID: "resource1:peerE",
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
|
||||
|
||||
// peerC is part of distribution groups for resource2 but should not receive the firewall rules
|
||||
firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerL is the single routing peer for resource5
|
||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap)
|
||||
assert.Len(t, routes, 1)
|
||||
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap)
|
||||
assert.Len(t, firewallRules, 1)
|
||||
assert.Len(t, sourcePeers, 1)
|
||||
|
||||
expectedFirewallRules = []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{"100.65.29.67/32"},
|
||||
Action: "accept",
|
||||
Destination: "10.12.12.1/32",
|
||||
Protocol: "tcp",
|
||||
Port: 8080,
|
||||
RouteID: "resource5:peerL",
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
|
||||
|
||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap)
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Len(t, sourcePeers, 0)
|
||||
|
||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap)
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Len(t, sourcePeers, 2)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1196,6 +1196,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
account.NameServerGroups[ns.ID] = &ns
|
||||
}
|
||||
account.NameServerGroupsG = nil
|
||||
account.InitOnce()
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
@@ -1634,6 +1635,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
if sExtraIntegratedValidatorGroups.Valid {
|
||||
_ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups)
|
||||
}
|
||||
account.InitOnce()
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
@@ -3308,7 +3310,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
result := tx.
|
||||
Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true).
|
||||
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||
Find(&peers, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
|
||||
|
||||
@@ -2729,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
{
|
||||
name: "should retrieve peers for an existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 5,
|
||||
expectedCount: 4,
|
||||
},
|
||||
{
|
||||
name: "should return no peers for a non-existing account ID",
|
||||
@@ -2751,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
name: "should filter peers by partial name",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
nameFilter: "host",
|
||||
expectedCount: 4,
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "should filter peers by ip",
|
||||
@@ -2777,16 +2777,14 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
expectedPeerIDs []string
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "should retrieve only non-expired peers with expiration enabled",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
expectedPeerIDs: []string{"notexpired01"},
|
||||
name: "should retrieve peers with expiration for an existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "should return no peers with expiration for a non-existing account ID",
|
||||
@@ -2805,30 +2803,10 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
for i, peer := range peers {
|
||||
assert.Equal(t, tt.expectedPeerIDs[i], peer.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned
|
||||
for _, peer := range peers {
|
||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned")
|
||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
@@ -2909,7 +2887,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) {
|
||||
name: "should retrieve peers for another valid account ID and user ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
expectedCount: 3,
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "should return no peers for existing account ID with empty user ID",
|
||||
|
||||
@@ -193,12 +193,20 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
})
|
||||
|
||||
// Hold on to req so auth's in-place ctx update is visible after ServeHTTP.
|
||||
req := r.WithContext(ctx)
|
||||
h.ServeHTTP(w, req)
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
close(handlerDone)
|
||||
|
||||
ctx = req.Context()
|
||||
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
|
||||
if err == nil {
|
||||
if userAuth.AccountId != "" {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
|
||||
}
|
||||
if userAuth.UserId != "" {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
|
||||
}
|
||||
}
|
||||
|
||||
if w.Status() > 399 {
|
||||
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
|
||||
|
||||
@@ -31,7 +31,6 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300
|
||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -108,9 +110,16 @@ type Account struct {
|
||||
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
|
||||
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
||||
|
||||
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
|
||||
nmapInitOnce *sync.Once `gorm:"-"`
|
||||
|
||||
ReverseProxyFreeDomainNonce string
|
||||
}
|
||||
|
||||
func (a *Account) InitOnce() {
|
||||
a.nmapInitOnce = &sync.Once{}
|
||||
}
|
||||
|
||||
// this class is used by gorm only
|
||||
type PrimaryAccountInfo struct {
|
||||
IsDomainPrimaryAccount bool
|
||||
@@ -146,6 +155,108 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
|
||||
o.SignupFormPending == onboarding.SignupFormPending
|
||||
}
|
||||
|
||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
|
||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
peerRoutesMembership := make(LookupMap)
|
||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||
}
|
||||
|
||||
for _, peer := range aclPeers {
|
||||
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
|
||||
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
||||
routes = append(routes, filteredRoutes...)
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
|
||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
_, found := peerMemberships[string(r.GetHAUniqueID())]
|
||||
if !found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
|
||||
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
for _, groupID := range r.Groups {
|
||||
_, found := groupListMap[groupID]
|
||||
if found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
// If the given is not a routing peer, then the lists are empty.
|
||||
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
|
||||
|
||||
peer := a.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
seenRoute := make(map[route.ID]struct{})
|
||||
|
||||
takeRoute := func(r *route.Route, id string) {
|
||||
if _, ok := seenRoute[r.ID]; ok {
|
||||
return
|
||||
}
|
||||
seenRoute[r.ID] = struct{}{}
|
||||
|
||||
if r.Enabled {
|
||||
r.Peer = peer.Key
|
||||
enabledRoutes = append(enabledRoutes, r)
|
||||
return
|
||||
}
|
||||
disabledRoutes = append(disabledRoutes, r)
|
||||
}
|
||||
|
||||
for _, r := range a.Routes {
|
||||
for _, groupID := range r.PeerGroups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
for _, id := range group.Peers {
|
||||
if id != peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
newPeerRoute := r.Copy()
|
||||
newPeerRoute.Peer = id
|
||||
newPeerRoute.PeerGroups = nil
|
||||
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
|
||||
takeRoute(newPeerRoute, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
if r.Peer == peerID {
|
||||
takeRoute(r.Copy(), peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
|
||||
var routes []*route.Route
|
||||
@@ -165,6 +276,106 @@ func (a *Account) GetGroup(groupID string) *Group {
|
||||
return a.Groups[groupID]
|
||||
}
|
||||
|
||||
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
||||
func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
groupIDToUserIDs map[string][]string,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peerID]; !ok {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
|
||||
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
for _, p := range aclPeers {
|
||||
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||
if a.Settings.PeerLoginExpirationEnabled && expired {
|
||||
expiredPeers = append(expiredPeers, p)
|
||||
continue
|
||||
}
|
||||
peersToConnect = append(peersToConnect, p)
|
||||
}
|
||||
|
||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
|
||||
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
||||
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
||||
var networkResourcesFirewallRules []*RouteFirewallRule
|
||||
if isRouter {
|
||||
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
|
||||
}
|
||||
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
|
||||
|
||||
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: dnsManagementStatus,
|
||||
}
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
|
||||
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||
zones = append(zones, filteredAccountZones...)
|
||||
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnectIncludingRouters,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: slices.Concat(networkResourcesRoutes, routesUpdate),
|
||||
DNSConfig: dnsUpdate,
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
EnableSSH: enableSSH,
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
if objectCount > 5000 {
|
||||
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
|
||||
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
|
||||
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
|
||||
}
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (a *Account) addNetworksRoutingPeers(
|
||||
networkResourcesRoutes []*route.Route,
|
||||
peer *nbpeer.Peer,
|
||||
@@ -210,6 +421,39 @@ func (a *Account) addNetworksRoutingPeers(
|
||||
return peersToConnect
|
||||
}
|
||||
|
||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
groupList := account.GetPeerGroups(peerID)
|
||||
|
||||
var peerNSGroups []*nbdns.NameServerGroup
|
||||
|
||||
for _, nsGroup := range account.NameServerGroups {
|
||||
if !nsGroup.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, gID := range nsGroup.Groups {
|
||||
_, found := groupList[gID]
|
||||
if found {
|
||||
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peerNSGroups
|
||||
}
|
||||
|
||||
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
||||
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if peer.IP.Equal(ns.IP.AsSlice()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) {
|
||||
for _, peer := range account.Peers {
|
||||
label, err := GetPeerHostLabel(peer.Name, peerLabels)
|
||||
@@ -556,6 +800,19 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
|
||||
return grps
|
||||
}
|
||||
|
||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
enabled := true
|
||||
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
|
||||
_, found := peerGroups[groupID]
|
||||
if found {
|
||||
enabled = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerGroups(peerID string) LookupMap {
|
||||
groupList := make(LookupMap)
|
||||
for groupID, group := range a.Groups {
|
||||
@@ -684,6 +941,8 @@ func (a *Account) Copy() *Account {
|
||||
NetworkResources: networkResources,
|
||||
Services: services,
|
||||
Onboarding: a.Onboarding,
|
||||
NetworkMapCache: a.NetworkMapCache,
|
||||
nmapInitOnce: a.nmapInitOnce,
|
||||
Domains: domains,
|
||||
}
|
||||
}
|
||||
@@ -1045,6 +1304,31 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
|
||||
|
||||
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
for _, route := range enabledRoutes {
|
||||
// If no access control groups are specified, accept all traffic.
|
||||
if len(route.AccessControlGroups) == 0 {
|
||||
defaultPermit := getDefaultPermit(route)
|
||||
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
|
||||
continue
|
||||
}
|
||||
|
||||
distributionPeers := a.getDistributionGroupsPeers(route)
|
||||
|
||||
for _, accessGroup := range route.AccessControlGroups {
|
||||
policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
|
||||
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
|
||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
||||
var fwRules []*RouteFirewallRule
|
||||
for _, policy := range policies {
|
||||
@@ -1103,6 +1387,50 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID
|
||||
return distributionGroupPeers
|
||||
}
|
||||
|
||||
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
|
||||
distPeers := make(map[string]struct{})
|
||||
for _, id := range route.Groups {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
distPeers[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return distPeers
|
||||
}
|
||||
|
||||
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
|
||||
var rules []*RouteFirewallRule
|
||||
|
||||
sources := []string{"0.0.0.0/0"}
|
||||
if route.Network.Addr().Is6() {
|
||||
sources = []string{"::/0"}
|
||||
}
|
||||
rule := RouteFirewallRule{
|
||||
SourceRanges: sources,
|
||||
Action: string(PolicyTrafficActionAccept),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(PolicyRuleProtocolALL),
|
||||
Domains: route.Domains,
|
||||
IsDynamic: route.IsDynamic(),
|
||||
RouteID: route.ID,
|
||||
}
|
||||
|
||||
rules = append(rules, &rule)
|
||||
|
||||
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
|
||||
if route.IsDynamic() {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
|
||||
// and returns a list of policies that have rules with destinations matching the specified groups.
|
||||
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
|
||||
@@ -1180,6 +1508,65 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
|
||||
return resourcePolicies
|
||||
}
|
||||
|
||||
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
|
||||
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
|
||||
var isRoutingPeer bool
|
||||
var routes []*route.Route
|
||||
allSourcePeers := make(map[string]struct{}, len(a.Peers))
|
||||
|
||||
for _, resource := range a.NetworkResources {
|
||||
if !resource.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
var addSourcePeers bool
|
||||
|
||||
networkRoutingPeers, exists := routers[resource.NetworkID]
|
||||
if exists {
|
||||
if router, ok := networkRoutingPeers[peerID]; ok {
|
||||
isRoutingPeer, addSourcePeers = true, true
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...)
|
||||
}
|
||||
}
|
||||
|
||||
addedResourceRoute := false
|
||||
for _, policy := range resourcePolicies[resource.ID] {
|
||||
var peers []string
|
||||
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
|
||||
peers = []string{policy.Rules[0].SourceResource.ID}
|
||||
} else {
|
||||
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
|
||||
}
|
||||
if addSourcePeers {
|
||||
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
|
||||
allSourcePeers[pID] = struct{}{}
|
||||
}
|
||||
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
|
||||
// add routes for the resource if the peer is in the distribution group
|
||||
for peerId, router := range networkRoutingPeers {
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
||||
}
|
||||
addedResourceRoute = true
|
||||
}
|
||||
if addedResourceRoute {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return isRoutingPeer, routes, allSourcePeers
|
||||
}
|
||||
|
||||
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
|
||||
var dest []string
|
||||
for _, peerID := range inputPeers {
|
||||
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
|
||||
dest = append(dest, peerID)
|
||||
}
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
|
||||
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
|
||||
for _, groupID := range groups {
|
||||
@@ -1271,6 +1658,22 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// getNetworkResourcesRoutes convert the network resources list to routes list.
|
||||
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
|
||||
resourceAppliedPolicies := resourcePolicies[resource.ID]
|
||||
|
||||
var routes []*route.Route
|
||||
// distribute the resource routes only if there is policy applied to it
|
||||
if len(resourceAppliedPolicies) > 0 {
|
||||
peer := a.GetPeer(peerId)
|
||||
if peer != nil {
|
||||
routes = append(routes, resource.ToRoute(peer, router))
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter {
|
||||
routers := make(map[string]map[string]*routerTypes.NetworkRouter)
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -17,6 +19,7 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -448,6 +451,402 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
|
||||
require.Len(t, result, 0)
|
||||
}
|
||||
|
||||
const (
|
||||
accID = "accountID"
|
||||
network1ID = "network1ID"
|
||||
group1ID = "group1"
|
||||
accNetResourcePeer1ID = "peer1"
|
||||
accNetResourcePeer2ID = "peer2"
|
||||
accNetResourceRouter1ID = "router1"
|
||||
accNetResource1ID = "resource1ID"
|
||||
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
|
||||
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
|
||||
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
|
||||
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
|
||||
)
|
||||
|
||||
var (
|
||||
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
|
||||
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
|
||||
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
|
||||
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
|
||||
)
|
||||
|
||||
func getBasicAccountsWithResource() *Account {
|
||||
return &Account{
|
||||
Id: accID,
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
accNetResourcePeer1ID: {
|
||||
ID: accNetResourcePeer1ID,
|
||||
AccountID: accID,
|
||||
Key: "peer1Key",
|
||||
IP: accNetResourcePeer1IP,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
WtVersion: "0.35.1",
|
||||
KernelVersion: "4.4.0",
|
||||
},
|
||||
},
|
||||
accNetResourcePeer2ID: {
|
||||
ID: accNetResourcePeer2ID,
|
||||
AccountID: accID,
|
||||
Key: "peer2Key",
|
||||
IP: accNetResourcePeer2IP,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "windows",
|
||||
WtVersion: "0.34.1",
|
||||
KernelVersion: "4.4.0",
|
||||
},
|
||||
},
|
||||
accNetResourceRouter1ID: {
|
||||
ID: accNetResourceRouter1ID,
|
||||
AccountID: accID,
|
||||
Key: "router1Key",
|
||||
IP: accNetResourceRouter1IP,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
WtVersion: "0.35.1",
|
||||
KernelVersion: "4.4.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
group1ID: {
|
||||
ID: group1ID,
|
||||
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
|
||||
},
|
||||
},
|
||||
Networks: []*networkTypes.Network{
|
||||
{
|
||||
ID: network1ID,
|
||||
AccountID: accID,
|
||||
Name: "network1",
|
||||
},
|
||||
},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{
|
||||
ID: accNetResourceRouter1ID,
|
||||
NetworkID: network1ID,
|
||||
AccountID: accID,
|
||||
Peer: accNetResourceRouter1ID,
|
||||
PeerGroups: []string{},
|
||||
Masquerade: false,
|
||||
Metric: 100,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{
|
||||
ID: accNetResource1ID,
|
||||
AccountID: accID,
|
||||
NetworkID: network1ID,
|
||||
Address: "10.10.10.0/24",
|
||||
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
|
||||
Type: resourceTypes.NetworkResourceType("subnet"),
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Policies: []*Policy{
|
||||
{
|
||||
ID: "policy1ID",
|
||||
AccountID: accID,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "rule1ID",
|
||||
Enabled: true,
|
||||
Sources: []string{group1ID},
|
||||
DestinationResource: Resource{
|
||||
ID: accNetResource1ID,
|
||||
Type: "Host",
|
||||
},
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Ports: []string{"80"},
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: nil,
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{
|
||||
ID: accNetResourceRestrictPostureCheckID,
|
||||
Name: accNetResourceRestrictPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.35.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: accNetResourceRelaxedPostureCheckID,
|
||||
Name: accNetResourceRelaxedPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: accNetResourceLockedPostureCheckID,
|
||||
Name: accNetResourceLockedPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "7.7.7",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: accNetResourceLinuxPostureCheckID,
|
||||
Name: accNetResourceLinuxPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
OSVersionCheck: &posture.OSVersionCheck{
|
||||
Linux: &posture.MinKernelVersionCheck{
|
||||
MinKernelVersion: "0.0.0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// all peers should match the policy
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// should allow peer1 to match the policy
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// should not match any peer
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 0, "expected rules count don't match")
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// should allow peer1 to match the policy
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
|
||||
|
||||
// should allow peer1 and peer2 to match the policy
|
||||
newPolicy := &Policy{
|
||||
ID: "policy2ID",
|
||||
AccountID: accID,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "policy2ID",
|
||||
Enabled: true,
|
||||
Sources: []string{group1ID},
|
||||
DestinationResource: Resource{
|
||||
ID: accNetResource1ID,
|
||||
Type: "Host",
|
||||
},
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Ports: []string{"22"},
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
|
||||
}
|
||||
|
||||
account.Policies = append(account.Policies, newPolicy)
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 2, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
|
||||
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
|
||||
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// two posture checks should match only the peers that match both checks
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}}
|
||||
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
|
||||
ID: "router2Id",
|
||||
NetworkID: network1ID,
|
||||
AccountID: accID,
|
||||
Peer: "router2Id",
|
||||
})
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
}
|
||||
|
||||
func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
47
management/server/types/holder.go
Normal file
47
management/server/types/holder.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Holder struct {
|
||||
mu sync.RWMutex
|
||||
accounts map[string]*Account
|
||||
}
|
||||
|
||||
func NewHolder() *Holder {
|
||||
return &Holder{
|
||||
accounts: make(map[string]*Account),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Holder) GetAccount(id string) *Account {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.accounts[id]
|
||||
}
|
||||
|
||||
func (h *Holder) AddAccount(account *Account) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
a := h.accounts[account.Id]
|
||||
if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() {
|
||||
return
|
||||
}
|
||||
h.accounts[account.Id] = account
|
||||
}
|
||||
|
||||
func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if acc, ok := h.accounts[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
account, err := accGetter(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.accounts[id] = account
|
||||
return account, nil
|
||||
}
|
||||
@@ -39,8 +39,6 @@ const (
|
||||
IdentityProviderTypeAuthentik IdentityProviderType = "authentik"
|
||||
// IdentityProviderTypeKeycloak is the Keycloak identity provider
|
||||
IdentityProviderTypeKeycloak IdentityProviderType = "keycloak"
|
||||
// IdentityProviderTypeADFS is the Microsoft AD FS identity provider
|
||||
IdentityProviderTypeADFS IdentityProviderType = "adfs"
|
||||
)
|
||||
|
||||
// IdentityProvider represents an identity provider configuration
|
||||
@@ -114,8 +112,7 @@ func (t IdentityProviderType) IsValid() bool {
|
||||
switch t {
|
||||
case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra,
|
||||
IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID,
|
||||
IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak,
|
||||
IdentityProviderTypeADFS:
|
||||
IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
67
management/server/types/networkmap.go
Normal file
67
management/server/types/networkmap.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
|
||||
if a.NetworkMapCache != nil {
|
||||
return
|
||||
}
|
||||
a.nmapInitOnce.Do(func() {
|
||||
a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeers map[string]struct{},
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) {
|
||||
if a.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
a.NetworkMapCache.EnqueuePeersForIncrementalAdd(a, peerIds...)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerDeleted(a, peerId)
|
||||
}
|
||||
|
||||
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
|
||||
if a.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
a.NetworkMapCache.UpdatePeer(peer)
|
||||
}
|
||||
|
||||
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
}
|
||||
592
management/server/types/networkmap_comparison_test.go
Normal file
592
management/server/types/networkmap_comparison_test.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
components := account.GetPeerNetworkMapComponents(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
if components == nil {
|
||||
t.Fatal("GetPeerNetworkMapComponents returned nil")
|
||||
}
|
||||
|
||||
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
|
||||
|
||||
if newNetworkMap == nil {
|
||||
t.Fatal("CalculateNetworkMapFromComponents returned nil")
|
||||
}
|
||||
|
||||
compareNetworkMaps(t, legacyNetworkMap, newNetworkMap)
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
components := account.GetPeerNetworkMapComponents(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")
|
||||
|
||||
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
|
||||
require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")
|
||||
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
|
||||
componentsJSON, err := json.MarshalIndent(components, "", " ")
|
||||
require.NoError(t, err, "error marshaling components to JSON")
|
||||
|
||||
legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
goldenDir := filepath.Join("testdata", "comparison")
|
||||
err = os.MkdirAll(goldenDir, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json")
|
||||
err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644)
|
||||
require.NoError(t, err, "error writing legacy golden file")
|
||||
|
||||
newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json")
|
||||
err = os.WriteFile(newGoldenPath, newJSON, 0644)
|
||||
require.NoError(t, err, "error writing components golden file")
|
||||
|
||||
componentsPath := filepath.Join(goldenDir, "components.json")
|
||||
err = os.WriteFile(componentsPath, componentsJSON, 0644)
|
||||
require.NoError(t, err, "error writing components golden file")
|
||||
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON),
|
||||
"NetworkMaps from legacy and components approaches do not match.\n"+
|
||||
"Legacy JSON saved to: %s\n"+
|
||||
"Components JSON saved to: %s",
|
||||
legacyGoldenPath, newGoldenPath)
|
||||
|
||||
t.Logf("✅ NetworkMaps are identical")
|
||||
t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath)
|
||||
t.Logf(" Components NetworkMap: %s", newGoldenPath)
|
||||
}
|
||||
|
||||
func normalizeAndSortNetworkMap(nm *NetworkMap) {
|
||||
if nm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
sort.Slice(nm.Peers, func(i, j int) bool {
|
||||
return nm.Peers[i].ID < nm.Peers[j].ID
|
||||
})
|
||||
|
||||
sort.Slice(nm.OfflinePeers, func(i, j int) bool {
|
||||
return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID
|
||||
})
|
||||
|
||||
sort.Slice(nm.Routes, func(i, j int) bool {
|
||||
return string(nm.Routes[i].ID) < string(nm.Routes[j].ID)
|
||||
})
|
||||
|
||||
sort.Slice(nm.FirewallRules, func(i, j int) bool {
|
||||
if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP {
|
||||
return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP
|
||||
}
|
||||
if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction {
|
||||
return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction
|
||||
}
|
||||
if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol {
|
||||
return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol
|
||||
}
|
||||
if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port {
|
||||
return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port
|
||||
}
|
||||
return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID
|
||||
})
|
||||
|
||||
for i := range nm.RoutesFirewallRules {
|
||||
sort.Strings(nm.RoutesFirewallRules[i].SourceRanges)
|
||||
}
|
||||
|
||||
sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool {
|
||||
if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination {
|
||||
return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination
|
||||
}
|
||||
|
||||
minLen := len(nm.RoutesFirewallRules[i].SourceRanges)
|
||||
if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen {
|
||||
minLen = len(nm.RoutesFirewallRules[j].SourceRanges)
|
||||
}
|
||||
for k := 0; k < minLen; k++ {
|
||||
if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] {
|
||||
return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k]
|
||||
}
|
||||
}
|
||||
if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) {
|
||||
return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges)
|
||||
}
|
||||
|
||||
if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) {
|
||||
return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID)
|
||||
}
|
||||
|
||||
if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID {
|
||||
return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID
|
||||
}
|
||||
|
||||
if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port {
|
||||
return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port
|
||||
}
|
||||
|
||||
return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol
|
||||
})
|
||||
|
||||
if nm.DNSConfig.CustomZones != nil {
|
||||
for i := range nm.DNSConfig.CustomZones {
|
||||
sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool {
|
||||
return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(nm.DNSConfig.NameServerGroups) != 0 {
|
||||
sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool {
|
||||
return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) {
|
||||
t.Helper()
|
||||
|
||||
if legacy.Network.Serial != current.Network.Serial {
|
||||
t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial)
|
||||
}
|
||||
|
||||
if len(legacy.Peers) != len(current.Peers) {
|
||||
t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers))
|
||||
}
|
||||
|
||||
legacyPeerIDs := make(map[string]bool)
|
||||
for _, p := range legacy.Peers {
|
||||
legacyPeerIDs[p.ID] = true
|
||||
}
|
||||
|
||||
for _, p := range current.Peers {
|
||||
if !legacyPeerIDs[p.ID] {
|
||||
t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(legacy.OfflinePeers) != len(current.OfflinePeers) {
|
||||
t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers))
|
||||
}
|
||||
|
||||
if len(legacy.FirewallRules) != len(current.FirewallRules) {
|
||||
t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules))
|
||||
}
|
||||
|
||||
if len(legacy.Routes) != len(current.Routes) {
|
||||
t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes))
|
||||
}
|
||||
|
||||
if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) {
|
||||
t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules))
|
||||
}
|
||||
|
||||
if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable {
|
||||
t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
numPeers = 100
|
||||
devGroupID = "group-dev"
|
||||
opsGroupID = "group-ops"
|
||||
allGroupID = "group-all"
|
||||
routeID = route.ID("route-main")
|
||||
routeHA1ID = route.ID("route-ha-1")
|
||||
routeHA2ID = route.ID("route-ha-2")
|
||||
policyIDDevOps = "policy-dev-ops"
|
||||
policyIDAll = "policy-all"
|
||||
policyIDPosture = "policy-posture"
|
||||
policyIDDrop = "policy-drop"
|
||||
postureCheckID = "posture-check-ver"
|
||||
networkResourceID = "res-database"
|
||||
networkID = "net-database"
|
||||
networkRouterID = "router-database"
|
||||
nameserverGroupID = "ns-group-main"
|
||||
testingPeerID = "peer-60"
|
||||
expiredPeerID = "peer-98"
|
||||
offlinePeerID = "peer-99"
|
||||
routingPeerID = "peer-95"
|
||||
testAccountID = "account-comparison-test"
|
||||
)
|
||||
|
||||
func createTestAccount() *Account {
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
|
||||
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
ip := net.IP{100, 64, 0, byte(i + 1)}
|
||||
wtVersion := "0.25.0"
|
||||
if i%2 == 0 {
|
||||
wtVersion = "0.40.0"
|
||||
}
|
||||
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
|
||||
}
|
||||
|
||||
if peerID == expiredPeerID {
|
||||
p.LoginExpirationEnabled = true
|
||||
pastTimestamp := time.Now().Add(-2 * time.Hour)
|
||||
p.LastLogin = &pastTimestamp
|
||||
}
|
||||
|
||||
peers[peerID] = p
|
||||
allGroupPeers = append(allGroupPeers, peerID)
|
||||
if i < numPeers/2 {
|
||||
devGroupPeers = append(devGroupPeers, peerID)
|
||||
} else {
|
||||
opsGroupPeers = append(opsGroupPeers, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
groups := map[string]*Group{
|
||||
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
|
||||
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
|
||||
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
|
||||
}
|
||||
|
||||
policies := []*Policy{
|
||||
{
|
||||
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
|
||||
Rules: []*PolicyRule{{
|
||||
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
|
||||
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
|
||||
Rules: []*PolicyRule{{
|
||||
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolTCP, Bidirectional: false,
|
||||
PortRanges: []RulePortRange{{Start: 8080, End: 8090}},
|
||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
|
||||
Rules: []*PolicyRule{{
|
||||
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop,
|
||||
Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
|
||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
|
||||
SourcePostureChecks: []string{postureCheckID},
|
||||
Rules: []*PolicyRule{{
|
||||
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
|
||||
Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
routes := map[route.ID]*route.Route{
|
||||
routeID: {
|
||||
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
|
||||
Peer: peers["peer-75"].Key,
|
||||
PeerID: "peer-75",
|
||||
Description: "Route to internal resource", Enabled: true,
|
||||
PeerGroups: []string{devGroupID, opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{devGroupID},
|
||||
},
|
||||
routeHA1ID: {
|
||||
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||
Peer: peers["peer-80"].Key,
|
||||
PeerID: "peer-80",
|
||||
Description: "HA Route 1", Enabled: true, Metric: 1000,
|
||||
PeerGroups: []string{allGroupID},
|
||||
Groups: []string{allGroupID},
|
||||
AccessControlGroups: []string{allGroupID},
|
||||
},
|
||||
routeHA2ID: {
|
||||
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||
Peer: peers["peer-90"].Key,
|
||||
PeerID: "peer-90",
|
||||
Description: "HA Route 2", Enabled: true, Metric: 900,
|
||||
PeerGroups: []string{devGroupID, opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{allGroupID},
|
||||
},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
|
||||
Network: &Network{
|
||||
Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
|
||||
},
|
||||
DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
|
||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||
nameserverGroupID: {
|
||||
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
|
||||
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
}},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
|
||||
},
|
||||
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
|
||||
},
|
||||
Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
|
||||
}
|
||||
|
||||
for _, p := range account.Policies {
|
||||
p.AccountID = account.Id
|
||||
}
|
||||
for _, r := range account.Routes {
|
||||
r.AccountID = account.Id
|
||||
}
|
||||
|
||||
return account
|
||||
}
|
||||
|
||||
func BenchmarkLegacyNetworkMap(b *testing.B) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid != offlinePeerID {
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = account.GetPeerNetworkMap(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkComponentsNetworkMap(b *testing.B) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid != offlinePeerID {
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
components := account.GetPeerNetworkMapComponents(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
_ = CalculateNetworkMapFromComponents(ctx, components)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkComponentsCreation(b *testing.B) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid != offlinePeerID {
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = account.GetPeerNetworkMapComponents(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCalculationFromComponents(b *testing.B) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid != offlinePeerID {
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
components := account.GetPeerNetworkMapComponents(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = CalculateNetworkMapFromComponents(ctx, components)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user