Compare commits

...

44 Commits

Author SHA1 Message Date
crn4
f357db6ffa Merge branch 'main' into nmap/compaction-deploy 2026-01-22 19:01:58 +01:00
crn4
0297976875 compaction of components 2026-01-22 19:01:18 +01:00
Bethuel Mmbaga
a1de2b8a98 [management] Move activity store encryption to shared crypt package (#5111) 2026-01-22 15:01:13 +03:00
Viktor Liu
d0221a3e72 [client] Add cpu profile to debug bundle (#4700) 2026-01-22 12:24:12 +01:00
Bethuel Mmbaga
8da23daae3 [management] Fix activity event initiator for user group changes (#5152) 2026-01-22 14:18:46 +03:00
Viktor Liu
f86022eace [client] Hide forwarding rules in status when count is zero (#5149)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 10:01:08 +01:00
Viktor Liu
ee54827f94 [client] Add IPv6 support to usersace bind (#5147) 2026-01-22 10:20:43 +08:00
crn4
b32cf57ed1 minor fixes on diffs 2026-01-21 17:12:34 +01:00
crn4
ad025e452b Merge branch 'main' into nmap/compaction-deploy 2026-01-21 16:14:18 +01:00
crn4
b533c173a4 minor fixes on diffs 2026-01-21 16:13:52 +01:00
Zoltan Papp
e908dea702 [client] Extend WG watcher for ICE connection too (#5133)
Extend WG watcher for ICE connection too
2026-01-21 10:42:13 +01:00
Maycon Santos
030650a905 [client] Fix RFC 4592 wildcard matching for existing domain names (#5145)
Per RFC 4592 section 2.2.1, wildcards should only match when the queried
name does not exist in the zone. Previously, if host.example.com had an
A record and *.example.com had an AAAA record, querying AAAA for
host.example.com would incorrectly return the wildcard AAAA instead of
NODATA.

Now the resolver checks if the domain exists (with any record type)
before falling back to wildcard matching, returning proper NODATA
responses for existing names without the requested record type.
2026-01-21 08:48:32 +01:00
crn4
9ccae8447b Merge branch 'main' into nmap/compaction-deploy 2026-01-20 19:50:51 +01:00
crn4
7941a91fcb diff fixes 2026-01-20 19:50:23 +01:00
Misha Bragin
e01998815e [infra] add embedded STUN to getting started (#5141) 2026-01-20 19:01:34 +01:00
Zoltan Papp
07e4a5a23c Fixes profile switching and repeated down/up command failures. (#5142)
When Down() and Up() are called in quick succession, the connectWithRetryRuns goroutine could set ErrResetConnection after Down() had cleared the state, causing the subsequent Up() to fail.

Fix by waiting for the goroutine to exit (via clientGiveUpChan) before Down() returns. Uses a 5-second timeout to prevent RPC timeouts while ensuring the goroutine completes in most cases.
2026-01-20 18:22:37 +01:00
Diego Romar
b3a2992a10 [client/android] - Fix Rosenpass connectivity for Android peers (#5044)
* [client] Add WGConfigurer interface

To allow Rosenpass to work both with kernel
WireGuard via wgctrl (default behavior) and
userspace WireGuard via IPC on Android/iOS
using WGUSPConfigurer

* [client] Remove Rosenpass debug logs

* [client] Return simpler peer configuration in outputKey method

ConfigureDevice, the method previously used in
outputKey via wgClient to update the device's
properties, is now defined in the WGConfigurer
interface and implemented both in kernel_unix and
usp configurers.

PresharedKey datatype was also changed from
boolean to [32]byte to compare it
to the original NetBird PSK, so that Rosenpass
may replace it with its own when necessary.

* [client] Remove unused field

* [client] Replace usage of WGConfigurer

Replaced with preshared key setter interface,
which only defines a method to set / update the preshared key.

Logic has been migrated from rosenpass/netbird_handler to client/iface.

* [client] Use same default peer keepalive value when setting preshared keys

* [client] Store PresharedKeySetter iface in rosenpass manager

To avoid no-op if SetInterface is called before generateConfig

* [client] Add mutex usage in rosenpass netbird handler

* [client] change implementation setting Rosenpass preshared key

Instead of providing a method to configure a device (device/interface.go),
it forwards the new parameters to the configurer (either
kernel_unix.go / usp.go).

This removes dependency on reading FullStats, and makes use of a common
method (buildPresharedKeyConfig in configurer/common.go) to build a
minimal WG config that only sets/updates the PSK.

netbird_handler.go now keeps s list of initializedPeers to choose whether
to set the value of "UpdateOnly" when calling iface.SetPresharedKey.

* [client] Address possible race condition

Between outputKey calls and peer removal; it
checks again if the peer still exists in the
peers map before inserting it in the
initializedPeers map.

* [client] Add psk Rosenpass-initialized check

On client/internal/peer/conn.go, the presharedKey
function would always return the current key
set in wgConfig.presharedKey.

This would eventually overwrite a key set
by Rosenpass if the feature is active.

The purpose here is to set a handler that will
check if a given peer has its psk initialized
by Rosenpass to skip updating the psk
via updatePeer (since it calls presharedKey
method in conn.go).

* Add missing updateOnly flag setup for usp peers

* Change common.go buildPresharedKeyConfig signature

PeerKey datatype changed from string to
wgTypes.Key. Callers are responsible for parsing
a peer key with string datatype.
2026-01-20 13:26:51 -03:00
Maycon Santos
202fa47f2b [client] Add support to wildcard custom records (#5125)
* **New Features**
  * Wildcard DNS fallback for eligible query types (excluding NS/SOA): attempts wildcard records when no exact match, rewrites wildcard names back to the original query, and rotates responses; preserves CNAME resolution.

* **Tests**
  * Vastly expanded coverage for wildcard behaviors, precedence, multi-record round‑robin, multi-type chains, multi-hop and cross-zone scenarios, and edge cases (NXDOMAIN/NODATA, fallthrough).

* **Chores**
  * CI lint config updated to ignore an additional codespell entry.
2026-01-20 17:21:25 +01:00
crn4
88dde22f07 small log 2026-01-20 15:20:34 +01:00
Misha Bragin
4888021ba6 Add missing activity events to the API response (#5140) 2026-01-20 15:12:22 +01:00
crn4
a23d538d09 opt over ssh users 2026-01-20 15:11:49 +01:00
Misha Bragin
a0b0b664b6 Local user password change (embedded IdP) (#5132) 2026-01-20 14:16:42 +01:00
crn4
6b15d14ef0 Merge branch 'main' into nmap/compaction-deploy 2026-01-20 14:13:01 +01:00
crn4
7598efe320 components network map for deploy - comparison with legacy 2026-01-20 14:12:13 +01:00
Diego Romar
50da5074e7 [client] change notifyDisconnected call (#5138)
On handleJobStream, when handling error codes 
from receiveJobRequest in the switch-case, 
notifying disconnected in cases where it isn't a 
disconnection breaks connection status reporting 
on mobile peers.

This commit changes it so it isn't called on
Canceled or Unimplemented status codes.
2026-01-20 07:14:33 -03:00
Zoltan Papp
58daa674ef [Management/Client] Trigger debug bundle runs from API/Dashboard (#4592) (#4832)
This PR adds the ability to trigger debug bundle generation remotely from the Management API/Dashboard.
2026-01-19 11:22:16 +01:00
Maycon Santos
245481f33b [client] fix: client/Dockerfile to reduce vulnerabilities (#5119)
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091701
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091701

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2026-01-16 18:05:41 +01:00
shuuri-labs
b352ab84c0 Feat/quickstart reverse proxy assistant (#5100)
* add external reverse proxy config steps to quickstart script

* remove generated files

* - Remove 'press enter' prompt from post-traefik config since traefik requires no manual config
- Improve npm flow (ask users for docker network, user container names in config)

* fixes for npm flow

* nginx flow fixes

* caddy flow fixes

* Consolidate NPM_NETWORK, NGINX_NETWORK, CADDY_NETWORK into single
EXTERNAL_PROXY_NETWORK variable. Add read_proxy_docker_network()
function that prompts for Docker network for options 2-4 (Nginx,
NPM, Caddy). Generated configs now use container names when a
Docker network is specified.

* fix https for traefik

* fix sonar code smells

* fix sonar smell (add return to render_dashboard_env)

* added tls instructions to nginx flow

* removed unused bind_addr variable from quickstart.sh

* Refactor getting-started.sh for improved maintainability

Break down large functions into focused, single-responsibility components:
- Split init_environment() into 6 initialization functions
- Split print_post_setup_instructions() into 6 proxy-specific functions
- Add section headers for better code organization
- Fix 3 code smell issues (unused bind_addr variables)
- Add TLS certificate documentation for Nginx
- Link reverse proxy names to docs sections

Reduces largest function from 205 to ~90 lines while maintaining
single-file distribution. No functional changes.

* - Remove duplicate network display logic in Traefik instructions
- Use upstream_host instead of bind_addr for NPM forward hostname
- Use upstream_host instead of bind_addr in manual proxy route examples
- Prevents displaying invalid 0.0.0.0 as connection target in setup instructions

* add wait_management_direct to caddy flow to ensure script waits until containers are running/passing healthchecks before reporting 'done!'
2026-01-16 17:42:28 +01:00
ressys1978
3ce5d6a4f8 [management] Add idp timeout env variable (#4647)
Introduced the NETBIRD_IDP_TIMEOUT environment variable to the management service. This allows configuring a timeout for supported IDPs. If the variable is unset or contains an invalid value, a default timeout of 10 seconds is used as a fallback.

This is needed for larger IDP environments where 10s is just not enough time.
2026-01-16 16:23:37 +01:00
Misha Bragin
4c2eb2af73 [management] Skip email_verified if not present (#5118) 2026-01-16 16:01:39 +01:00
Misha Bragin
daf1449174 [client] Remove duplicate audiences check (#5117) 2026-01-16 14:25:02 +02:00
Misha Bragin
1ff7abe909 [management, client] Fix SSH server audience validator (#5105)
* **New Features**
  * SSH server JWT validation now accepts multiple audiences with backward-compatible handling of the previous single-audience setting and a guard ensuring at least one audience is configured.
* **Tests**
  * Test suites updated and new tests added to cover multiple-audience scenarios and compatibility with existing behavior.
* **Other**
  * Startup logging enhanced to report configured audiences for JWT auth.
2026-01-16 12:28:17 +01:00
Bethuel Mmbaga
067c77e49e [management] Add custom dns zones (#4849) 2026-01-16 12:12:05 +03:00
Maycon Santos
291e640b28 [client] Change priority between local and dns route handlers (#5106)
* Change priority between local and dns route handlers

* update priority tests
2026-01-15 17:30:10 +01:00
Pascal Fischer
efb954b7d6 [management] adapt ratelimiting (#5080) 2026-01-15 16:39:14 +01:00
Vlad
cac9326d3d [management] fetch all users data from external cache in one request (#5104)
---------

Co-authored-by: pascal <pascal@netbird.io>
2026-01-14 17:09:17 +01:00
Viktor Liu
520d9c66cf [client] Fix netstack upstream dns and add wasm debug methods (#4648) 2026-01-14 13:56:16 +01:00
Misha Bragin
ff10498a8b Feature/embedded STUN (#5062) 2026-01-14 13:13:30 +01:00
Zoltan Papp
00b747ad5d Handle fallback for invalid loginuid in ui-post-install.sh. (#5099) 2026-01-14 09:53:14 +01:00
Zoltan Papp
d9118eb239 [client] Fix WASM peer connection to lazy peers (#5097)
WASM peers now properly initiate relay connections instead of waiting for offers that lazy peers won't send.
2026-01-13 13:33:15 +01:00
Nima Sadeghifard
94de656fae [misc] Add hiring announcement with link to careers.netbird.io (#5095) 2026-01-12 19:06:28 +01:00
Misha Bragin
37abab8b69 [management] Check config compatibility (#5087)
* Enforce HttpConfig overwrite when embeddedIdp is enabled

* Disable offline_access scope in dashboard by default

* Add group propagation foundation to embedded idp

* Require groups scope in dex config for okt and pocket

* remove offline_access from device default scopes
2026-01-12 17:09:03 +01:00
Viktor Liu
b12c084a50 [client] Fall through dns chain for custom dns zones (#5081) 2026-01-12 13:56:39 +01:00
Viktor Liu
394ad19507 [client] Chase CNAMEs in local resolver to ensure musl compatibility (#5046) 2026-01-12 12:35:38 +01:00
179 changed files with 20400 additions and 3193 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum
golangci:
strategy:

View File

@@ -38,6 +38,11 @@
</strong>
<br>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
<br>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.2
FROM alpine:3.23.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
@@ -98,7 +97,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -221,21 +219,37 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
cpuProfilingStarted := false
if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to start CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = true
defer func() {
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
}
}
}()
}
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = false
}
}
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -302,25 +316,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
return nil
}
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context(), true)
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
)
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -379,7 +374,8 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
InternalConfig: config,
StatusRecorder: recorder,
SyncResponse: syncResponse,
LogFile: logFilePath,
LogPath: logFilePath,
CPUProfile: nil,
},
debug.BundleConfig{
IncludeSystemInfo: true,

View File

@@ -99,17 +99,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string
switch {
case detailFlag:
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
statusOutputString = outputInformationHolder.FullDetailSummary()
case jsonFlag:
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
statusOutputString, err = outputInformationHolder.JSON()
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
statusOutputString, err = outputInformationHolder.YAML()
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
}
if err != nil {

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
@@ -97,6 +98,8 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersmanager)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -115,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
@@ -124,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
@@ -38,6 +39,7 @@ type Client struct {
setupKey string
jwtToken string
connect *internal.ConnectClient
recorder *peer.Status
}
// Options configures a new Client.
@@ -161,11 +163,17 @@ func New(opts Options) (*Client, error) {
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
if c.connect != nil {
return ErrClientAlreadyStarted
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
defer func() {
if c.connect == nil {
cancel()
}
}()
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
@@ -173,14 +181,16 @@ func (c *Client) Start(startCtx context.Context) error {
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
if err := client.Run(run, ""); err != nil {
clientErr <- err
}
}()
@@ -197,6 +207,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
c.connect = client
c.cancel = cancel
return nil
}
@@ -211,17 +222,23 @@ func (c *Client) Stop(ctx context.Context) error {
return ErrClientNotStarted
}
if c.cancel != nil {
c.cancel()
c.cancel = nil
}
done := make(chan error, 1)
connect := c.connect
go func() {
done <- c.connect.Stop()
done <- connect.Stop()
}()
select {
case <-ctx.Done():
c.cancel = nil
c.connect = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
c.connect = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
@@ -315,6 +332,62 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
_ = engine.RunHealthProbes(false)
}
}
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
syncResp, err := engine.GetLatestSyncResponse()
if err != nil {
return nil, fmt.Errorf("get sync response: %w", err)
}
return syncResp, nil
}
// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
if err != nil {
return fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
c.mu.Lock()
connect := c.connect
c.mu.Unlock()
if connect != nil {
connect.SetLogLevel(level)
}
return nil
}
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.

View File

@@ -0,0 +1,169 @@
package bind
import (
"errors"
"net"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
var (
errNoIPv4Conn = errors.New("no IPv4 connection available")
errNoIPv6Conn = errors.New("no IPv6 connection available")
errInvalidAddr = errors.New("invalid address type")
)
// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes
// to the appropriate connection based on the destination address.
// ReadFrom is not used in the hot path - ICEBind receives packets via
// BatchReader.ReadBatch() directly. This is only used by udpMux for sending.
type DualStackPacketConn struct {
ipv4Conn net.PacketConn
ipv6Conn net.PacketConn
readFromWarn sync.Once
}
// NewDualStackPacketConn creates a new dual-stack packet connection.
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
return &DualStackPacketConn{
ipv4Conn: ipv4Conn,
ipv6Conn: ipv6Conn,
}
}
// ReadFrom reads from the available connection (preferring IPv4).
// NOTE: This method is NOT used in the data path. ICEBind receives packets via
// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient.
// This implementation exists only to satisfy the net.PacketConn interface for the udpMux,
// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom()
// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path.
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
d.readFromWarn.Do(func() {
log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path")
})
if d.ipv4Conn != nil {
return d.ipv4Conn.ReadFrom(b)
}
if d.ipv6Conn != nil {
return d.ipv6Conn.ReadFrom(b)
}
return 0, nil, net.ErrClosed
}
// WriteTo writes to the appropriate connection based on the address type.
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, &net.OpError{
Op: "write",
Net: "udp",
Addr: addr,
Err: errInvalidAddr,
}
}
if udpAddr.IP.To4() == nil {
if d.ipv6Conn != nil {
return d.ipv6Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp6",
Addr: addr,
Err: errNoIPv6Conn,
}
}
if d.ipv4Conn != nil {
return d.ipv4Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp4",
Addr: addr,
Err: errNoIPv4Conn,
}
}
// Close closes both connections.
func (d *DualStackPacketConn) Close() error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// LocalAddr returns the local address of the IPv4 connection if available,
// otherwise the IPv6 connection.
func (d *DualStackPacketConn) LocalAddr() net.Addr {
if d.ipv4Conn != nil {
return d.ipv4Conn.LocalAddr()
}
if d.ipv6Conn != nil {
return d.ipv6Conn.LocalAddr()
}
return nil
}
// SetDeadline sets the deadline for both connections.
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetReadDeadline sets the read deadline for both connections.
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetWriteDeadline sets the write deadline for both connections.
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@@ -0,0 +1,119 @@
package bind
import (
"net"
"testing"
)
var (
ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345}
payload = make([]byte, 1200)
)
func BenchmarkWriteTo_DirectUDPConn(b *testing.B) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = conn.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn.Close()
ds := NewDualStackPacketConn(conn, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) {
conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn.Close()
ds := NewDualStackPacketConn(nil, conn)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv6Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv6Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
addrs := []net.Addr{ipv4Addr, ipv6Addr}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, addrs[i&1])
}
}

View File

@@ -0,0 +1,191 @@
package bind
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) {
ipv4Conn := &mockPacketConn{network: "udp4"}
ipv6Conn := &mockPacketConn{network: "udp6"}
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
tests := []struct {
name string
addr *net.UDPAddr
wantSocket string
}{
{
name: "IPv4 address",
addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv6 address",
addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
wantSocket: "udp6",
},
{
name: "IPv4-mapped IPv6 goes to IPv4",
addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv4 loopback",
addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv6 loopback",
addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234},
wantSocket: "udp6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ipv4Conn.writeCount = 0
ipv6Conn.writeCount = 0
n, err := dualStack.WriteTo([]byte("test"), tt.addr)
require.NoError(t, err)
assert.Equal(t, 4, n)
if tt.wantSocket == "udp4" {
assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4")
assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6")
} else {
assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4")
assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6")
}
})
}
}
func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) {
dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil)
// IPv4 works
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
require.NoError(t, err)
// IPv6 fails
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
require.Error(t, err)
assert.Contains(t, err.Error(), "no IPv6 connection")
}
func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) {
dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"})
// IPv6 works
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
require.NoError(t, err)
// IPv4 fails
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
require.Error(t, err)
assert.Contains(t, err.Error(), "no IPv4 connection")
}
// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom
// only reads from one socket (IPv4 preferred). This is fine because the actual
// receive path uses wireguard-go's BatchReader directly, not ReadFrom.
func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) {
ipv4Conn := &mockPacketConn{
network: "udp4",
readData: []byte("from ipv4"),
readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
}
ipv6Conn := &mockPacketConn{
network: "udp6",
readData: []byte("from ipv6"),
readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
}
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
buf := make([]byte, 100)
n, addr, err := dualStack.ReadFrom(buf)
require.NoError(t, err)
// reads from IPv4 (preferred) - this is expected behavior
assert.Equal(t, "from ipv4", string(buf[:n]))
assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String())
}
func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) {
ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820}
ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820}
tests := []struct {
name string
ipv4 net.PacketConn
ipv6 net.PacketConn
wantAddr net.Addr
}{
{
name: "both available returns IPv4",
ipv4: &mockPacketConn{localAddr: ipv4Addr},
ipv6: &mockPacketConn{localAddr: ipv6Addr},
wantAddr: ipv4Addr,
},
{
name: "IPv4 only",
ipv4: &mockPacketConn{localAddr: ipv4Addr},
ipv6: nil,
wantAddr: ipv4Addr,
},
{
name: "IPv6 only",
ipv4: nil,
ipv6: &mockPacketConn{localAddr: ipv6Addr},
wantAddr: ipv6Addr,
},
{
name: "neither returns nil",
ipv4: nil,
ipv6: nil,
wantAddr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6)
assert.Equal(t, tt.wantAddr, dualStack.LocalAddr())
})
}
}
// mock
type mockPacketConn struct {
network string
writeCount int
readData []byte
readAddr net.Addr
localAddr net.Addr
}
func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
if m.readData != nil {
return copy(b, m.readData), m.readAddr, nil
}
return 0, nil, nil
}
func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
m.writeCount++
return len(b), nil
}
func (m *mockPacketConn) Close() error { return nil }
func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr }
func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil }

View File

@@ -14,7 +14,6 @@ import (
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
@@ -28,22 +27,7 @@ type receiverCreator struct {
}
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
}
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
buf := bufs[0]
size, ep, err := conn.ReadFromUDPAddrPort(buf)
if err != nil {
return 0, err
}
sizes[0] = size
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
eps[0] = stdEp
return 1, nil
}
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
@@ -73,6 +57,8 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
ipv4Conn *net.UDPConn
ipv6Conn *net.UDPConn
}
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
@@ -118,6 +104,12 @@ func (s *ICEBind) Close() error {
close(s.closedChan)
s.muUDPMux.Lock()
s.ipv4Conn = nil
s.ipv6Conn = nil
s.udpMux = nil
s.muUDPMux.Unlock()
return s.StdNetBind.Close()
}
@@ -175,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
// Detect IPv4 vs IPv6 from connection's local address
if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil {
s.ipv4Conn = conn
} else {
s.ipv6Conn = conn
}
s.createOrUpdateMux()
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
for i := range bufs {
@@ -195,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
//nolint:staticcheck
_, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
@@ -222,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok {
continue
}
sizes[i] = msg.N
@@ -248,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
}
}
// createOrUpdateMux creates or updates the UDP mux with the available connections.
// Must be called with muUDPMux held.
func (s *ICEBind) createOrUpdateMux() {
var muxConn net.PacketConn
switch {
case s.ipv4Conn != nil && s.ipv6Conn != nil:
muxConn = NewDualStackPacketConn(
nbnet.WrapPacketConn(s.ipv4Conn),
nbnet.WrapPacketConn(s.ipv6Conn),
)
case s.ipv4Conn != nil:
muxConn = nbnet.WrapPacketConn(s.ipv4Conn)
case s.ipv6Conn != nil:
muxConn = nbnet.WrapPacketConn(s.ipv6Conn)
default:
return
}
// Don't close the old mux - it doesn't own the underlying connections.
// The sockets are managed by WireGuard's StdNetBind, not by us.
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: muxConn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
@@ -260,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
s.muUDPMux.Lock()
mux := s.udpMux
s.muUDPMux.Unlock()
if mux != nil {
if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil {
log.Warnf("failed to handle STUN packet: %v", muxErr)
}
}
buffers[i] = []byte{}

View File

@@ -0,0 +1,324 @@
package bind
import (
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/pion/transport/v3/stdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) {
iceBind := setupICEBind(t)
ipv4Conn, ipv6Conn := createDualStackConns(t)
defer ipv4Conn.Close()
defer ipv6Conn.Close()
rc := receiverCreator{iceBind}
pool := createMsgPool()
// Simulate wireguard-go calling CreateReceiverFn for IPv4
ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool)
require.NotNil(t, ipv4RecvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection")
assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet")
assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection")
iceBind.muUDPMux.Unlock()
// Simulate wireguard-go calling CreateReceiverFn for IPv6
ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool)
require.NotNil(t, ipv6RecvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection")
assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection")
assert.NotNil(t, iceBind.udpMux, "mux should still exist")
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
func TestICEBind_WorksWithIPv4Only(t *testing.T) {
iceBind := setupICEBind(t)
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
defer ipv4Conn.Close()
rc := receiverCreator{iceBind}
recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool())
require.NotNil(t, recvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn)
assert.Nil(t, iceBind.ipv6Conn)
assert.NotNil(t, iceBind.udpMux)
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
func TestICEBind_WorksWithIPv6Only(t *testing.T) {
iceBind := setupICEBind(t)
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Conn.Close()
rc := receiverCreator{iceBind}
recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool())
require.NotNil(t, recvFn)
iceBind.muUDPMux.Lock()
assert.Nil(t, iceBind.ipv4Conn)
assert.NotNil(t, iceBind.ipv6Conn)
assert.NotNil(t, iceBind.udpMux)
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate
// with peers on different address families through the same DualStackPacketConn.
func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) {
// two "remote peers" listening on different address families
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Peer.Close()
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Peer.Close()
// our local dual-stack connection
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Local.Close()
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
defer ipv6Local.Close()
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
// send to both peers
_, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr())
require.NoError(t, err)
_, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr())
require.NoError(t, err)
// verify IPv4 peer got its packet from the IPv4 socket
buf := make([]byte, 100)
_ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := ipv4Peer.ReadFrom(buf)
require.NoError(t, err)
assert.Equal(t, "to-ipv4", string(buf[:n]))
assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
// verify IPv6 peer got its packet from the IPv6 socket
_ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err = ipv6Peer.ReadFrom(buf)
require.NoError(t, err)
assert.Equal(t, "to-ipv6", string(buf[:n]))
assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
}
// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4
// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets,
// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP.
func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Peer.Close()
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Peer.Close()
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Local.Close()
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
defer ipv6Local.Close()
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
const packetsPerFamily = 500
ipv4Received := make(chan string, packetsPerFamily)
ipv6Received := make(chan string, packetsPerFamily)
startGate := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 100)
for i := 0; i < packetsPerFamily; i++ {
n, _, err := ipv4Peer.ReadFrom(buf)
if err != nil {
return
}
ipv4Received <- string(buf[:n])
}
}()
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 100)
for i := 0; i < packetsPerFamily; i++ {
n, _, err := ipv6Peer.ReadFrom(buf)
if err != nil {
return
}
ipv6Received <- string(buf[:n])
}
}()
wg.Add(1)
go func() {
defer wg.Done()
<-startGate
for i := 0; i < packetsPerFamily; i++ {
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr())
}
}()
wg.Add(1)
go func() {
defer wg.Done()
<-startGate
for i := 0; i < packetsPerFamily; i++ {
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr())
}
}()
close(startGate)
time.AfterFunc(5*time.Second, func() {
_ = ipv4Peer.SetReadDeadline(time.Now())
_ = ipv6Peer.SetReadDeadline(time.Now())
})
wg.Wait()
close(ipv4Received)
close(ipv6Received)
ipv4Count := 0
for pkt := range ipv4Received {
require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt)
ipv4Count++
}
ipv6Count := 0
for pkt := range ipv6Received {
require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt)
ipv6Count++
}
assert.Equal(t, packetsPerFamily, ipv4Count)
assert.Equal(t, packetsPerFamily, ipv6Count)
}
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
tests := []struct {
name string
network string
addr string
wantIPv4 bool
}{
{"IPv4 any", "udp4", "0.0.0.0:0", true},
{"IPv4 loopback", "udp4", "127.0.0.1:0", true},
{"IPv6 any", "udp6", "[::]:0", false},
{"IPv6 loopback", "udp6", "[::1]:0", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := net.ResolveUDPAddr(tt.network, tt.addr)
require.NoError(t, err)
conn, err := net.ListenUDP(tt.network, addr)
if err != nil {
t.Skipf("%s not available: %v", tt.network, err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
isIPv4 := localAddr.IP.To4() != nil
assert.Equal(t, tt.wantIPv4, isIPv4)
})
}
}
// helpers
func setupICEBind(t *testing.T) *ICEBind {
t.Helper()
transportNet, err := stdnet.NewNet()
require.NoError(t, err)
address := wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/10"),
}
return NewICEBind(transportNet, nil, address, 1280)
}
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
t.Helper()
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
ipv4Conn.Close()
t.Skipf("IPv6 not available: %v", err)
}
return ipv4Conn, ipv6Conn
}
func createMsgPool() *sync.Pool {
return &sync.Pool{
New: func() any {
msgs := make([]ipv6.Message, 1)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, 0, 40)
}
return &msgs
},
}
}
func listenUDP(t *testing.T, network, addr string) *net.UDPConn {
t.Helper()
udpAddr, err := net.ResolveUDPAddr(network, addr)
require.NoError(t, err)
conn, err := net.ListenUDP(network, udpAddr)
require.NoError(t, err)
return conn
}

View File

@@ -3,8 +3,22 @@ package configurer
import (
"net"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// buildPresharedKeyConfig creates a wgtypes.Config for setting a preshared key on a peer.
// This is a shared helper used by both kernel and userspace configurers.
func buildPresharedKeyConfig(peerKey wgtypes.Key, psk wgtypes.Key, updateOnly bool) wgtypes.Config {
return wgtypes.Config{
Peers: []wgtypes.PeerConfig{{
PublicKey: peerKey,
PresharedKey: &psk,
UpdateOnly: updateOnly,
}},
}
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {

View File

@@ -15,8 +15,6 @@ import (
"github.com/netbirdio/netbird/monotime"
)
var zeroKey wgtypes.Key
type KernelConfigurer struct {
deviceName string
}
@@ -48,6 +46,18 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
// SetPresharedKey sets the preshared key for a peer.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *KernelConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
return c.configure(cfg)
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -279,7 +289,7 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
PresharedKey: [32]byte(p.PresharedKey),
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint

View File

@@ -22,17 +22,16 @@ import (
)
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
@@ -72,6 +71,18 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
// SetPresharedKey sets the preshared key for a peer.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *WGUSPConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
return c.device.IpcSet(toWgUserspaceString(cfg))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -422,23 +433,19 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
hexKey := hex.EncodeToString(p.PublicKey[:])
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
if p.Remove {
sb.WriteString("remove=true\n")
}
if p.UpdateOnly {
sb.WriteString("update_only=true\n")
}
if p.PresharedKey != nil {
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
}
if p.Remove {
sb.WriteString("remove=true")
}
if p.ReplaceAllowedIPs {
sb.WriteString("replace_allowed_ips=true\n")
}
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
if p.Endpoint != nil {
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
}
@@ -446,6 +453,14 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
if p.PersistentKeepaliveInterval != nil {
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
}
if p.ReplaceAllowedIPs {
sb.WriteString("replace_allowed_ips=true\n")
}
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
}
return sb.String()
}
@@ -599,7 +614,9 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
currentPeer.PresharedKey = true
if pskKey, err := hexToWireguardKey(val); err == nil {
currentPeer.PresharedKey = [32]byte(pskKey)
}
}
}
}

View File

@@ -12,7 +12,7 @@ type Peer struct {
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
PresharedKey [32]byte
}
type Stats struct {

View File

@@ -17,6 +17,7 @@ type WGConfigurer interface {
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)

View File

@@ -297,6 +297,19 @@ func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
}
// SetPresharedKey sets or updates the preshared key for a peer.
// If updateOnly is true, only updates existing peer; if false, creates or updates.
func (w *WGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
return w.configurer.SetPresharedKey(peerKey, psk, updateOnly)
}
func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime)

View File

@@ -59,7 +59,6 @@ func NewConnectClient(
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
@@ -71,8 +70,8 @@ func NewConnectClient(
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan)
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
return c.run(MobileDependency{}, runningChan, logPath)
}
// RunOnAndroid with main logic on mobile system
@@ -93,7 +92,7 @@ func (c *ConnectClient) RunOnAndroid(
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) RunOniOS(
@@ -111,10 +110,10 @@ func (c *ConnectClient) RunOniOS(
DnsManager: dnsManager,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -284,7 +283,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
relayURLs, token := parseRelayInfo(loginResp)
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
if err != nil {
log.Error(err)
return wrapErr(err)
@@ -420,6 +419,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
return syncResponse, nil
}
// SetLogLevel sets the log level for the firewall manager if the engine is running.
func (c *ConnectClient) SetLogLevel(level log.Level) {
engine := c.Engine()
if engine == nil {
return
}
fwManager := engine.GetFirewallManager()
if fwManager != nil {
fwManager.SetLogLevel(level)
}
}
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {
@@ -459,7 +471,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
@@ -494,7 +506,10 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
LazyConnectionEnabled: config.LazyConnectionEnabled,
MTU: selectMTU(config.MTU, peerConfig.Mtu),
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
ProfileConfig: config,
}
if config.PreSharedKey != "" {

View File

@@ -28,8 +28,10 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const readmeContent = `Netbird debug bundle
@@ -57,6 +59,7 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
cpu.prof: CPU profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
@@ -223,10 +226,10 @@ type BundleGenerator struct {
internalConfig *profilemanager.Config
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logFile string
logPath string
cpuProfile []byte
anonymize bool
clientStatus string
includeSystemInfo bool
logFileCount uint32
@@ -235,7 +238,6 @@ type BundleGenerator struct {
type BundleConfig struct {
Anonymize bool
ClientStatus string
IncludeSystemInfo bool
LogFileCount uint32
}
@@ -244,7 +246,8 @@ type GeneratorDependencies struct {
InternalConfig *profilemanager.Config
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogFile string
LogPath string
CPUProfile []byte
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -260,10 +263,10 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
internalConfig: deps.InternalConfig,
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logFile: deps.LogFile,
logPath: deps.LogPath,
cpuProfile: deps.CPUProfile,
anonymize: cfg.Anonymize,
clientStatus: cfg.ClientStatus,
includeSystemInfo: cfg.IncludeSystemInfo,
logFileCount: logFileCount,
}
@@ -309,13 +312,6 @@ func (g *BundleGenerator) createArchive() error {
return fmt.Errorf("add status: %w", err)
}
if g.statusRecorder != nil {
status := g.statusRecorder.GetFullStatus()
seedFromStatus(g.anonymizer, &status)
} else {
log.Debugf("no status recorder available for seeding")
}
if err := g.addConfig(); err != nil {
log.Errorf("failed to add config to debug bundle: %v", err)
}
@@ -332,6 +328,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addCPUProfile(); err != nil {
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
@@ -352,7 +352,7 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
@@ -401,11 +401,26 @@ func (g *BundleGenerator) addReadme() error {
}
func (g *BundleGenerator) addStatus() error {
if status := g.clientStatus; status != "" {
statusReader := strings.NewReader(status)
if g.statusRecorder != nil {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
fullStatus := g.statusRecorder.GetFullStatus()
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
return fmt.Errorf("add status file to zip: %w", err)
}
seedFromStatus(g.anonymizer, &fullStatus)
} else {
log.Debugf("no status recorder available for seeding")
}
return nil
}
@@ -535,6 +550,19 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
func (g *BundleGenerator) addCPUProfile() error {
if len(g.cpuProfile) == 0 {
return nil
}
reader := bytes.NewReader(g.cpuProfile)
if err := g.addFileToZip(reader, "cpu.prof"); err != nil {
return fmt.Errorf("add CPU profile to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
@@ -710,14 +738,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
}
func (g *BundleGenerator) addLogfile() error {
if g.logFile == "" {
if g.logPath == "" {
log.Debugf("skipping empty log file in debug bundle")
return nil
}
logDir := filepath.Dir(g.logFile)
logDir := filepath.Dir(g.logPath)
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
return fmt.Errorf("add client log file to zip: %w", err)
}

View File

@@ -0,0 +1,101 @@
package debug
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/netbirdio/netbird/upload-server/types"
)
const maxBundleUploadSize = 50 * 1024 * 1024
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}

View File

@@ -1,4 +1,4 @@
package server
package debug
import (
"context"
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
fileContent := []byte("test file content")
err := os.WriteFile(file, fileContent, 0640)
require.NoError(t, err)
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
require.NoError(t, err)
id := getURLHash(testURL)
require.Contains(t, key, id+"/")

View File

@@ -60,7 +60,7 @@ func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey {
if peer.PresharedKey != [32]byte{} {
sb.WriteString(" preshared key: (hidden)\n")
}
}

View File

@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
if zone.SkipPTRProcess {
if zone.NonAuthoritative {
continue
}
for _, record := range zone.Records {

View File

@@ -3,17 +3,21 @@ package dns
import (
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
)
const (
PriorityMgmtCache = 150
PriorityLocal = 100
PriorityDNSRoute = 75
PriorityDNSRoute = 100
PriorityLocal = 75
PriorityUpstream = 50
PriorityDefault = 1
PriorityFallback = -100
@@ -43,7 +47,23 @@ type HandlerChain struct {
type ResponseWriterChain struct {
dns.ResponseWriter
origPattern string
requestID string
shouldContinue bool
response *dns.Msg
meta map[string]string
}
// RequestID returns the request ID for tracing
func (w *ResponseWriterChain) RequestID() string {
return w.requestID
}
// SetMeta sets a metadata key-value pair for logging
func (w *ResponseWriterChain) SetMeta(key, value string) {
if w.meta == nil {
w.meta = make(map[string]string)
}
w.meta[key] = value
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
@@ -52,6 +72,7 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
w.shouldContinue = true
return nil
}
w.response = m
return w.ResponseWriter.WriteMsg(m)
}
@@ -101,6 +122,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
pos := c.findHandlerPosition(entry)
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
c.logHandlers()
}
// findHandlerPosition determines where to insert a new handler based on priority and specificity
@@ -140,68 +163,109 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
c.logHandlers()
break
}
}
}
// logHandlers logs the current handler chain state. Caller must hold the lock.
func (c *HandlerChain) logHandlers() {
if !log.IsLevelEnabled(log.TraceLevel) {
return
}
var b strings.Builder
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
for _, h := range c.handlers {
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
" priority=" + strconv.Itoa(h.Priority) + "\n")
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
qname := strings.ToLower(r.Question[0].Name)
startTime := time.Now()
requestID := resutil.GenerateRequestID()
logger := log.WithFields(log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
})
question := r.Question[0]
qname := strings.ToLower(question.Name)
c.mu.RLock()
handlers := slices.Clone(c.handlers)
c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) {
var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers {
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
// Try handlers in priority order
for _, entry := range handlers {
matched := c.isHandlerMatch(qname, entry)
if matched {
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
// Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue
}
return
if !c.isHandlerMatch(qname, entry) {
continue
}
handlerName := entry.OrigPattern
if s, ok := entry.Handler.(interface{ String() string }); ok {
handlerName = s.String()
}
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
requestID: requestID,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
if entry.Priority != PriorityMgmtCache {
logger.Tracef("handler requested continue for domain=%s", qname)
}
continue
}
c.logResponse(logger, chainWriter, qname, startTime)
return
}
// No handler matched or all handlers passed
log.Tracef("no handler found for domain=%s", qname)
logger.Tracef("no handler found for domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeRefused)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
}
}
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
if cw.response == nil {
return
}
var meta string
for k, v := range cw.meta {
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":

View File

@@ -1,30 +1,52 @@
package local
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
)
const externalResolutionTimeout = 4 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
ctx context.Context
cancel context.CancelFunc
}
func NewResolver() *Resolver {
ctx, cancel := context.WithCancel(context.Background())
return &Resolver{
records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}),
zones: make(map[domain.Domain]bool),
ctx: ctx,
cancel: cancel,
}
}
@@ -37,7 +59,18 @@ func (d *Resolver) String() string {
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {}
func (d *Resolver) Stop() {
if d.cancel != nil {
d.cancel()
}
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
@@ -48,60 +81,147 @@ func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithField("request_id", resutil.GetRequestID(w))
if len(r.Question) == 0 {
log.Debugf("received local resolver request with no question")
logger.Debug("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// Check if we have any records for this domain name with different types
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
} else {
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
}
result := d.lookupRecords(logger, question)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
d.continueToNext(logger, w, r)
return
}
if err := w.WriteMsg(replyMessage); err != nil {
log.Warnf("failed to write the local resolver response: %v", err)
logger.Warnf("failed to write the local resolver response: %v", err)
}
}
// determineRcode returns the appropriate DNS response code.
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
// even if CNAME records are included in the answer.
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
// Use the rcode from lookup - this properly handles CNAME chains where
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
if result.rcode != 0 {
return result.rcode
}
// No records found, but domain exists with different record types (NODATA)
if d.hasRecordsForDomain(domain.Domain(question.Name), question.Qtype) {
return dns.RcodeSuccess
}
return dns.RcodeNameError
}
// findZone finds the matching zone for a query name using reverse suffix lookup.
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
qname = strings.ToLower(dns.Fqdn(qname))
for {
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
return nonAuth, true
}
// Move to parent domain
idx := strings.Index(qname, ".")
if idx == -1 || idx == len(qname)-1 {
return false, false
}
qname = qname[idx+1:]
}
}
// shouldFallthrough checks if the query should fallthrough to the next handler.
// Returns true if the queried name belongs to a non-authoritative zone.
func (d *Resolver) shouldFallthrough(qname string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
nonAuth, found := d.findZone(qname)
return found && nonAuth
}
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Warnf("failed to write continue signal: %v", err)
}
}
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain, qType uint16) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, exists := d.domains[domainName]
if !exists && supportsWildcard(qType) {
testWild := transformDomainToWildcard(string(domainName))
_, exists = d.domains[domain.Domain(testWild)]
}
return exists
}
// isInManagedZone checks if the given name falls within any of our managed zones.
// This is used to avoid unnecessary external resolution for CNAME targets that
// are within zones we manage - if we don't have a record for it, it doesn't exist.
// Caller must NOT hold the lock.
func (d *Resolver) isInManagedZone(name string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, found := d.findZone(name)
return found
}
// lookupResult contains the result of a DNS lookup operation.
type lookupResult struct {
records []dns.RR
rcode int
hasExternalData bool
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
d.mu.RLock()
records, found := d.records[question]
usingWildcard := false
wildQuestion := transformToWildcard(question)
// RFC 4592 section 2.2.1: wildcard only matches if the name does NOT exist in the zone.
// If the domain exists with any record type, return NODATA instead of wildcard match.
if !found && supportsWildcard(question.Qtype) {
if _, domainExists := d.domains[domain.Domain(question.Name)]; !domainExists {
records, found = d.records[wildQuestion]
usingWildcard = found
}
}
if !found {
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
cnameQuestion := dns.Question{
Name: question.Name,
Qtype: dns.TypeCNAME,
Qclass: question.Qclass,
}
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
}
return nil
return lookupResult{rcode: dns.RcodeNameError}
}
recordsCopy := slices.Clone(records)
@@ -110,29 +230,229 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
// if there's more than one record, rotate them (round-robin)
if len(recordsCopy) > 1 {
d.mu.Lock()
records = d.records[question]
q := question
if usingWildcard {
q = wildQuestion
}
records = d.records[q]
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records[question] = records
d.records[q] = records
}
d.mu.Unlock()
}
return recordsCopy
if usingWildcard {
return responseFromWildRecords(question.Name, wildQuestion.Name, recordsCopy)
}
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
}
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
func transformToWildcard(question dns.Question) dns.Question {
wildQuestion := question
wildQuestion.Name = transformDomainToWildcard(wildQuestion.Name)
return wildQuestion
}
func transformDomainToWildcard(domain string) string {
s := strings.Split(domain, ".")
s[0] = "*"
return strings.Join(s, ".")
}
func supportsWildcard(queryType uint16) bool {
return queryType != dns.TypeNS && queryType != dns.TypeSOA
}
func responseFromWildRecords(originalName, wildName string, wildRecords []dns.RR) lookupResult {
records := make([]dns.RR, len(wildRecords))
for i, record := range wildRecords {
copiedRecord := dns.Copy(record)
copiedRecord.Header().Name = originalName
records[i] = copiedRecord
}
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
// the final resolved record of the requested type. This is required for musl libc
// compatibility, which expects the full answer chain rather than just the CNAME.
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
const maxDepth = 8
var chain []dns.RR
for range maxDepth {
cnameRecords := d.getRecords(cnameQuestion)
if len(cnameRecords) == 0 && supportsWildcard(targetType) {
wildQuestion := transformToWildcard(cnameQuestion)
if wildRecords := d.getRecords(wildQuestion); len(wildRecords) > 0 {
cnameRecords = responseFromWildRecords(cnameQuestion.Name, wildQuestion.Name, wildRecords).records
}
}
if len(cnameRecords) == 0 {
break
}
chain = append(chain, cnameRecords...)
cname, ok := cnameRecords[0].(*dns.CNAME)
if !ok {
break
}
targetName := strings.ToLower(cname.Target)
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
// keep following chain
if targetResult.rcode == -1 {
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
continue
}
return d.buildChainResult(chain, targetResult)
}
if len(chain) > 0 {
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
}
return lookupResult{rcode: dns.RcodeSuccess}
}
// buildChainResult combines CNAME chain records with the target resolution result.
// Per RFC 6604, the final rcode is propagated through the chain.
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
records := chain
if len(target.records) > 0 {
records = append(records, target.records...)
}
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
return lookupResult{
records: records,
rcode: dns.RcodeServerFailure,
hasExternalData: true,
}
}
return lookupResult{
records: records,
rcode: target.rcode,
hasExternalData: target.hasExternalData,
}
}
// resolveCNAMETarget attempts to resolve a CNAME target name.
// Returns rcode=-1 to signal "keep following the chain".
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// another CNAME, keep following
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
return lookupResult{rcode: -1}
}
// domain exists locally but not this record type (NODATA)
if d.hasRecordsForDomain(domain.Domain(targetName), targetType) {
return lookupResult{rcode: dns.RcodeSuccess}
}
// in our zone but doesn't exist (NXDOMAIN)
if d.isInManagedZone(targetName) {
return lookupResult{rcode: dns.RcodeNameError}
}
return d.resolveExternal(logger, targetName, targetType)
}
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
d.mu.RLock()
defer d.mu.RUnlock()
return d.records[q]
}
func (d *Resolver) hasRecord(q dns.Question) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, ok := d.records[q]
return ok
}
// resolveExternal resolves a domain name using the system resolver.
// This is used to resolve CNAME targets that point outside our local zone,
// which is required for musl libc compatibility (musl expects complete answers).
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return lookupResult{rcode: dns.RcodeNotImplemented}
}
resolver := d.resolver
if resolver == nil {
resolver = net.DefaultResolver
}
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
defer cancel()
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
if result.Err != nil {
d.logDNSError(logger, name, qtype, result.Err)
return lookupResult{rcode: result.Rcode, hasExternalData: true}
}
return lookupResult{
records: resutil.IPsToRRs(name, result.IPs, 60),
rcode: dns.RcodeSuccess,
hasExternalData: true,
}
}
// logDNSError logs DNS resolution errors for debugging.
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
qtypeName := dns.TypeToString[qtype]
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
return
}
if dnsErr.IsNotFound {
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
return
}
if dnsErr.Server != "" {
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
} else {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
}
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
for _, zone := range customZones {
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
d.zones[zoneDomain] = zone.NonAuthoritative
for _, rec := range zone.Records {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,197 @@
// Package resutil provides shared DNS resolution utilities
package resutil
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"net"
"net/netip"
"strings"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
// GenerateRequestID creates a random 8-character hex string for request tracing.
func GenerateRequestID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
log.Errorf("generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// IPsToRRs converts a slice of IP addresses to DNS resource records.
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
var result []dns.RR
for _, ip := range ips {
if ip.Is6() {
result = append(result, &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: ip.AsSlice(),
})
} else {
result = append(result, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: ip.AsSlice(),
})
}
}
return result
}
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
// Returns empty string for unsupported types.
func NetworkForQtype(qtype uint16) string {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
return ""
}
}
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// chainedWriter is implemented by ResponseWriters that carry request metadata
type chainedWriter interface {
RequestID() string
SetMeta(key, value string)
}
// GetRequestID extracts a request ID from the ResponseWriter if available,
// otherwise generates a new one.
func GetRequestID(w dns.ResponseWriter) string {
if cw, ok := w.(chainedWriter); ok {
if id := cw.RequestID(); id != "" {
return id
}
}
return GenerateRequestID()
}
// SetMeta sets metadata on the ResponseWriter if it supports it.
func SetMeta(w dns.ResponseWriter, key, value string) {
if cw, ok := w.(chainedWriter); ok {
cw.SetMeta(key, value)
}
}
// LookupResult contains the result of an external DNS lookup
type LookupResult struct {
IPs []netip.Addr
Rcode int
Err error // Original error for caller's logging needs
}
// LookupIP performs a DNS lookup and determines the appropriate rcode.
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
ips, err := r.LookupNetIP(ctx, network, host)
if err != nil {
return LookupResult{
Rcode: getRcodeForError(ctx, r, host, qtype, err),
Err: err,
}
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
return LookupResult{
IPs: ips,
Rcode: dns.RcodeSuccess,
}
}
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
return dns.RcodeServerFailure
}
if dnsErr.IsNotFound {
return getRcodeForNotFound(ctx, r, host, qtype)
}
return dns.RcodeServerFailure
}
// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
// (domain exists but no records of requested type) by checking the opposite record type.
//
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
// are not queried by musl and don't need this handling.
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
return dns.RcodeNameError
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
return dns.RcodeSuccess
}
// Alternative query succeeded - domain exists but has no records of this type
return dns.RcodeSuccess
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {
return "[]"
}
parts := make([]string, 0, len(answers))
for _, rr := range answers {
switch r := rr.(type) {
case *dns.A:
parts = append(parts, r.A.String())
case *dns.AAAA:
parts = append(parts, r.AAAA.String())
case *dns.CNAME:
parts = append(parts, "CNAME:"+r.Target)
case *dns.PTR:
parts = append(parts, "PTR:"+r.Ptr)
default:
parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
}
}
return "[" + strings.Join(parts, ", ") + "]"
}

View File

@@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
}
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("local handler updater: %w", err)
}
@@ -498,8 +498,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates)
// register local records
s.localResolver.Update(localRecords)
s.localResolver.Update(localZones)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -632,9 +631,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,
@@ -659,9 +656,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
var muxUpdates []handlerWrapper
var localRecords []nbdns.SimpleRecord
var zones []nbdns.CustomZone
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
@@ -675,17 +672,20 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
priority: PriorityLocal,
})
// zone records contain the fqdn, so we can just flatten them
var localRecords []nbdns.SimpleRecord
for _, record := range customZone.Records {
if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class)
continue
}
// zone records contain the fqdn, so we can just flatten them
localRecords = append(localRecords, record)
}
customZone.Records = localRecords
zones = append(zones, customZone)
}
return muxUpdates, localRecords, nil
return muxUpdates, zones, nil
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
@@ -741,9 +741,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
domainGroup.domain,
@@ -924,9 +922,7 @@ func (s *DefaultServer) addHostRootZone() {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
return configurer.WGStats{}, nil
}
func (w *mocWGIface) GetNet() *netstack.Net {
return nil
}
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
@@ -128,7 +133,7 @@ func TestUpdateDNSServer(t *testing.T) {
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initLocalRecords []nbdns.SimpleRecord
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
@@ -180,8 +185,8 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
@@ -221,19 +226,19 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -251,11 +256,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -273,11 +278,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -299,8 +304,8 @@ func TestUpdateDNSServer(t *testing.T) {
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -315,8 +320,8 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -385,7 +390,7 @@ func TestUpdateDNSServer(t *testing.T) {
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalRecords)
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -510,8 +515,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
priority: PriorityUpstream,
},
}
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
@@ -2048,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")

View File

@@ -2,7 +2,6 @@ package dns
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -19,8 +18,10 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -113,10 +114,7 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
logger := log.WithField("request_id", resutil.GetRequestID(w))
u.prepareRequest(r)
@@ -202,11 +200,18 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
resutil.SetMeta(w, "upstream", upstream.String())
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
rm.MsgHdr.Zero = false
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
@@ -414,16 +419,56 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
return nil, err
}
return hex.EncodeToString(bytes)
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
}
return reply, nil
}
func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
conn, err := nsNet.DialContext(ctx, network, upstream)
if err != nil {
return nil, fmt.Errorf("with %s: %w", network, err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close DNS connection: %v", err)
}
}()
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
}
dnsConn := &dns.Conn{Conn: conn}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)
}
reply, err := dnsConn.ReadMsg()
if err != nil {
return nil, fmt.Errorf("read %s message: %w", network, err)
}
return reply, nil
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected

View File

@@ -23,9 +23,7 @@ type upstreamResolver struct {
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
func newUpstreamResolver(
ctx context.Context,
_ string,
_ netip.Addr,
_ netip.Prefix,
_ WGIface,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,

View File

@@ -5,22 +5,23 @@ package dns
import (
"context"
"net/netip"
"runtime"
"time"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/internal/peer"
)
type upstreamResolver struct {
*upstreamResolverBase
nsNet *netstack.Net
}
func newUpstreamResolver(
ctx context.Context,
_ string,
_ netip.Addr,
_ netip.Prefix,
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -28,12 +29,23 @@ func newUpstreamResolver(
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),
}
upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil
}
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
// TODO: Check if upstream DNS server is routed through a peer before using netstack.
// Similar to iOS logic, we should determine if the DNS server is reachable directly
// or needs to go through the tunnel, and only use netstack when necessary.
// For now, only use netstack on JS platform where direct access is not possible.
if u.nsNet != nil && runtime.GOOS == "js" {
start := time.Now()
reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
return reply, time.Since(start), err
}
client := &dns.Client{
Timeout: ClientTimeout,
}

View File

@@ -26,9 +26,7 @@ type upstreamResolverIOS struct {
func newUpstreamResolver(
ctx context.Context,
interfaceName string,
ip netip.Addr,
net netip.Prefix,
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -37,9 +35,9 @@ func newUpstreamResolver(
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: ip,
lNet: net,
interfaceName: interfaceName,
lIP: wgIface.Address().IP,
lNet: wgIface.Address().Network,
interfaceName: wgIface.Name(),
}
ios.upstreamClient = ios

View File

@@ -2,13 +2,17 @@ package dns
import (
"context"
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/test"
)
@@ -58,7 +62,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
// Convert test servers to netip.AddrPort
var servers []netip.AddrPort
for _, server := range testCase.InputServers {
@@ -112,6 +116,19 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
}
}
type mockNetstackProvider struct{}
func (m *mockNetstackProvider) Name() string { return "mock" }
func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} }
func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration

View File

@@ -5,6 +5,8 @@ package dns
import (
"net"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -17,4 +19,5 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
}

View File

@@ -1,6 +1,8 @@
package dns
import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -12,5 +14,6 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
GetInterfaceGUIDString() (string, error)
}

View File

@@ -18,6 +18,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
)
@@ -189,29 +190,22 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 {
return nil
}
question := query.Question[0]
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
question.Name, question.Qtype, question.Qclass)
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
var network string
switch question.Qtype {
case dns.TypeA:
network = "ip4"
case dns.TypeAAAA:
network = "ip6"
default:
// TODO: Handle other types
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
@@ -221,33 +215,35 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(ctx, w, question, resp, domain, err)
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
return nil
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
@@ -265,19 +261,33 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -315,140 +325,64 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
}
}
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
//
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
// only handles A/AAAA queries and returns NOTIMP for other types.
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
resp.Rcode = dns.RcodeNameError
return
}
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
resp.Rcode = dns.RcodeNameError
return
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
resp.Rcode = dns.RcodeSuccess
return
}
// Alternative query succeeded - domain exists but has no records of this type
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
func (f *DNSForwarder) handleDNSError(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
question dns.Question,
resp *dns.Msg,
domain string,
err error,
result resutil.LookupResult,
) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
// Prefer typed DNS errors; fall back to generic logging otherwise.
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
log.Warnf(errResolveFailed, domain, err)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
resp.Rcode = result.Rcode
// NotFound: set NXDOMAIN / appropriate code via helper.
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
// Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 {
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips)
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write cached DNS response: %v", writeErr)
}
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
return
}
// Cached negative result - re-verify NXDOMAIN vs NODATA
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
return
}
// No cache. Log with or without the server field for more context.
if dnsErr.Server != "" {
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else {
log.Warnf(errResolveFailed, domain, err)
logger.Warnf(errResolveFailed, domain, result.Err)
}
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
for _, ip := range ips {
var respRecord dns.RR
if ip.Is6() {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
}

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@@ -317,7 +318,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
@@ -465,7 +466,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
@@ -527,7 +528,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
@@ -604,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -674,7 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -684,7 +685,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -714,7 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -728,7 +729,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -783,7 +784,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -904,7 +905,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
@@ -937,7 +938,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
@@ -42,12 +43,14 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -132,6 +135,11 @@ type EngineConfig struct {
LazyConnectionEnabled bool
MTU uint16
// for debug bundle generation
ProfileConfig *profilemanager.Config
LogPath string
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -195,7 +203,8 @@ type Engine struct {
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
// Sync response persistence
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
@@ -211,6 +220,9 @@ type Engine struct {
shutdownWg sync.WaitGroup
probeStunTurn *relay.StunTurnProbe
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
}
// Peer is an instance of the Connection Peer
@@ -224,7 +236,18 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
stateManager *statemanager.Manager,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -244,6 +267,7 @@ func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signa
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -312,6 +336,8 @@ func (e *Engine) Stop() error {
e.cancel()
}
e.jobExecutorWG.Wait() // block until job goroutines finish
e.close()
// stop flow manager after wg interface is gone
@@ -479,6 +505,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("up wg interface: %w", err)
}
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
}
// if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall)
@@ -500,6 +531,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.receiveSignalEvents()
e.receiveManagementEvents()
e.receiveJobEvents()
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
@@ -828,9 +860,18 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
e.syncRespMux.RUnlock()
// Store sync response if persistence is enabled
if e.persistSyncResponse {
if enabled {
e.syncRespMux.Lock()
e.latestSyncResponse = update
e.syncRespMux.Unlock()
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
@@ -960,6 +1001,77 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil
}
func (e *Engine) receiveJobEvents() {
e.jobExecutorWG.Add(1)
go func() {
defer e.jobExecutorWG.Done()
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
resp := mgmProto.JobResponse{
ID: msg.ID,
Status: mgmProto.JobStatus_failed,
}
switch params := msg.WorkloadParameters.(type) {
case *mgmProto.JobRequest_Bundle:
bundleResult, err := e.handleBundle(params.Bundle)
if err != nil {
log.Errorf("handling bundle: %v", err)
resp.Reason = []byte(err.Error())
return &resp
}
resp.Status = mgmProto.JobStatus_succeeded
resp.WorkloadResults = bundleResult
return &resp
default:
resp.Reason = []byte(jobexec.ErrJobNotImplemented.Error())
return &resp
}
})
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return
}
log.Info("stopped receiving jobs from Management Service")
}()
log.Info("connecting to Management Service jobs stream")
}
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
log.Infof("handle remote debug bundle request: %s", params.String())
syncResponse, err := e.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
bundleDeps := debug.GeneratorDependencies{
InternalConfig: e.config.ProfileConfig,
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
}
bundleJobParams := debug.BundleConfig{
Anonymize: params.Anonymize,
IncludeSystemInfo: true,
LogFileCount: uint32(params.LogFileCount),
}
waitFor := time.Duration(params.BundleForTime) * time.Minute
uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, waitFor, e.config.ProfileConfig.ManagementURL.String())
if err != nil {
return nil, err
}
response := &mgmProto.JobResponse_Bundle{
Bundle: &mgmProto.BundleResult{
UploadKey: uploadKey,
},
}
return response, nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
@@ -1251,11 +1363,16 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
ForwarderPort: forwarderPort,
}
for _, zone := range protoDNSConfig.GetCustomZones() {
protoZones := protoDNSConfig.GetCustomZones()
// Treat single zone as authoritative for backward compatibility with old servers
// that only send the peer FQDN zone without setting field 4.
singleZoneCompat := len(protoZones) == 1
for _, zone := range protoZones {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
SkipPTRProcess: zone.GetSkipPTRProcess(),
NonAuthoritative: zone.GetNonAuthoritative() && !singleZoneCompat,
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{
@@ -1400,6 +1517,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
if e.rpManager != nil {
peerConn.SetOnConnected(e.rpManager.OnConnected)
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
peerConn.SetRosenpassInitializedPresharedKeyValidator(e.rpManager.IsPresharedKeyInitialized)
}
return peerConn, nil
@@ -1743,22 +1861,26 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
}
e.syncMsgMux.Unlock()
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)
// Skip STUN/TURN probing for JS/WASM as it's not available
relayHealthy := true
for _, res := range results {
if res.Err != nil {
relayHealthy = false
break
if runtime.GOOS != "js" {
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)
for _, res := range results {
if res.Err != nil {
relayHealthy = false
break
}
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy)
@@ -1839,8 +1961,8 @@ func (e *Engine) stopDNSServer() {
// SetSyncResponsePersistence enables or disables sync response persistence
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.syncRespMux.Lock()
defer e.syncRespMux.Unlock()
if enabled == e.persistSyncResponse {
return
@@ -1855,20 +1977,22 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
latest := e.latestSyncResponse
e.syncRespMux.RUnlock()
if !e.persistSyncResponse {
if !enabled {
return nil, errors.New("sync response persistence is disabled")
}
if e.latestSyncResponse == nil {
if latest == nil {
//nolint:nilnil
return nil, nil
}
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
if !ok {
return nil, fmt.Errorf("failed to clone sync response")
}

View File

@@ -72,9 +72,16 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
log.Debugf("starting SSH server with JWT authentication: audiences=%v", audiences)
jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audience: protoJWT.GetAudience(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}

View File

@@ -25,6 +25,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations"
@@ -213,6 +214,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
return nil
}
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
return nil
}
func TestMain(m *testing.M) {
_ = util.InitLog("debug", util.LogConsole)
code := m.Run()
@@ -1599,6 +1604,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
@@ -1622,7 +1628,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
@@ -1631,7 +1637,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -42,4 +42,5 @@ type wgIfaceBase interface {
GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
}

View File

@@ -88,8 +88,9 @@ type Conn struct {
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
rosenpassInitializedPresharedKeyValidator func(peerKey string) bool
statusRelay *worker.AtomicWorkerStatus
statusICE *worker.AtomicWorkerStatus
@@ -98,7 +99,10 @@ type Conn struct {
workerICE *WorkerICE
workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup
wgWatcher *WGWatcher
wgWatcherWg sync.WaitGroup
wgWatcherCancel context.CancelFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte
@@ -126,6 +130,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
connLog := log.WithField("peer", config.Key)
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
@@ -137,8 +142,9 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
}
return conn, nil
@@ -162,7 +168,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@@ -231,7 +237,9 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Infof("close peer connection")
conn.ctxCancel()
conn.workerRelay.DisableWgWatcher()
if conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
}
conn.workerRelay.CloseConn()
conn.workerICE.Close()
@@ -289,6 +297,13 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler
}
// SetRosenpassInitializedPresharedKeyValidator sets a function to check if Rosenpass has taken over
// PSK management for a peer. When this returns true, presharedKey() returns nil
// to prevent UpdatePeer from overwriting the Rosenpass-managed PSK.
func (conn *Conn) SetRosenpassInitializedPresharedKeyValidator(handler func(peerKey string) bool) {
conn.rosenpassInitializedPresharedKeyValidator = handler
}
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
conn.dumpState.RemoteOffer()
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
@@ -366,9 +381,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp
}
conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause()
}
@@ -390,6 +402,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep)
}
conn.enableWgWatcherIfNeeded()
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
@@ -423,11 +437,6 @@ func (conn *Conn) onICEStateDisconnected() {
conn.Log.Errorf("failed to switch to relay conn: %v", err)
}
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay
} else {
@@ -444,15 +453,15 @@ func (conn *Conn) onICEStateDisconnected() {
}
conn.statusICE.SetDisconnected()
conn.disableWgWatcherIfNeeded()
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(),
}
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
if err != nil {
if err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState); err != nil {
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
}
}
@@ -500,11 +509,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.enableWgWatcherIfNeeded()
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
@@ -519,7 +524,11 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
func (conn *Conn) onRelayDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
conn.handleRelayDisconnectedLocked()
}
// handleRelayDisconnectedLocked handles relay disconnection. Caller must hold conn.mu.
func (conn *Conn) handleRelayDisconnectedLocked() {
if conn.ctx.Err() != nil {
return
}
@@ -545,6 +554,8 @@ func (conn *Conn) onRelayDisconnected() {
}
conn.statusRelay.SetDisconnected()
conn.disableWgWatcherIfNeeded()
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
@@ -563,6 +574,28 @@ func (conn *Conn) onGuardEvent() {
}
}
func (conn *Conn) onWGDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
return
}
conn.Log.Warnf("WireGuard handshake timeout detected, closing current connection")
// Close the active connection based on current priority
switch conn.currentConnPriority {
case conntype.Relay:
conn.workerRelay.CloseConn()
conn.handleRelayDisconnectedLocked()
case conntype.ICEP2P, conntype.ICETurn:
conn.workerICE.Close()
default:
conn.Log.Debugf("No active connection to close on WG timeout")
}
}
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{
PubKey: conn.config.Key,
@@ -669,10 +702,17 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
// For JS platform: only relay connection is supported
if runtime.GOOS == "js" {
return conn.statusRelay.Get() == worker.StatusConnected
}
// For non-JS platforms: check ICE connection status
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}
// If relay is supported with peer, it must also be connected
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == worker.StatusDisconnected {
return false
@@ -682,6 +722,25 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true
}
func (conn *Conn) enableWgWatcherIfNeeded() {
if !conn.wgWatcher.IsEnabled() {
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
conn.wgWatcherCancel = wgWatcherCancel
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected)
}()
}
}
func (conn *Conn) disableWgWatcherIfNeeded() {
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
conn.wgWatcherCancel = nil
}
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{
@@ -752,10 +811,24 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return conn.config.WgConfig.PreSharedKey
}
// If Rosenpass has already set a PSK for this peer, return nil to prevent
// UpdatePeer from overwriting the Rosenpass-managed key.
if conn.rosenpassInitializedPresharedKeyValidator != nil && conn.rosenpassInitializedPresharedKeyValidator(conn.config.Key) {
return nil
}
// Use NetBird PSK as the seed for Rosenpass. This same PSK is passed to
// Rosenpass as PeerConfig.PresharedKey, ensuring the derived post-quantum
// key is cryptographically bound to the original secret.
if conn.config.WgConfig.PreSharedKey != nil {
return conn.config.WgConfig.PreSharedKey
}
// Fallback to deterministic key if no NetBird PSK is configured
determKey, err := conn.rosenpassDetermKey()
if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return conn.config.WgConfig.PreSharedKey
return nil
}
return determKey

View File

@@ -284,3 +284,27 @@ func TestConn_presharedKey(t *testing.T) {
})
}
}
func TestConn_presharedKey_RosenpassManaged(t *testing.T) {
conn := Conn{
config: ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
RosenpassConfig: RosenpassConfig{PubKey: []byte("dummykey")},
},
}
// When Rosenpass has already initialized the PSK for this peer,
// presharedKey must return nil to avoid UpdatePeer overwriting it.
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return true }
if k := conn.presharedKey([]byte("remote")); k != nil {
t.Fatalf("expected nil presharedKey when Rosenpass manages PSK, got %v", k)
}
// When Rosenpass hasn't taken over yet, presharedKey should provide
// a non-nil initial key (deterministic or from NetBird PSK).
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return false }
if k := conn.presharedKey([]byte("remote")); k == nil {
t.Fatalf("expected non-nil presharedKey before Rosenpass manages PSK")
}
}

View File

@@ -14,6 +14,7 @@ import (
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -158,6 +159,7 @@ type FullStatus struct {
NSGroupStates []NSGroupState
NumOfForwardingRules int
LazyConnectionEnabled bool
Events []*proto.SystemEvent
}
type StatusChangeSubscription struct {
@@ -981,6 +983,7 @@ func (d *Status) GetFullStatus() FullStatus {
}
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
fullStatus.Events = d.GetEventHistory()
return fullStatus
}
@@ -1181,3 +1184,97 @@ type EventSubscription struct {
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
return s.events
}
// ToProto converts FullStatus to proto.FullStatus.
func (fs FullStatus) ToProto() *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fs.ManagementState.URL
pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
if err := fs.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fs.SignalState.URL
pbFullStatus.SignalState.Connected = fs.SignalState.Connected
if err := fs.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)
for _, peerState := range fs.Peers {
networks := maps.Keys(peerState.GetRoutes())
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: networks,
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fs.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fs.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
pbFullStatus.Events = fs.Events
return &pbFullStatus
}

View File

@@ -30,10 +30,8 @@ type WGWatcher struct {
peerKey string
stateDump *stateDump
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
enabledTime time.Time
enabled bool
muEnabled sync.RWMutex
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -46,52 +44,44 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
}
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
w.enabledTime = time.Now()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
w.ctxLock.Unlock()
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
w.muEnabled.Lock()
if w.enabled {
w.muEnabled.Unlock()
return
}
ctx, ctxCancel := context.WithCancel(parentCtx)
w.ctx = ctx
w.ctxCancel = ctxCancel
w.ctxLock.Unlock()
w.log.Debugf("enable WireGuard watcher")
enabledTime := time.Now()
w.enabled = true
w.muEnabled.Unlock()
initialHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read initial wg stats: %v", err)
}
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)
w.muEnabled.Lock()
w.enabled = false
w.muEnabled.Unlock()
}
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
func (w *WGWatcher) DisableWgWatcher() {
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctxCancel == nil {
return
}
w.log.Debugf("disable WireGuard watcher")
w.ctxCancel()
w.ctxCancel = nil
// IsEnabled returns true if the WireGuard watcher is currently enabled
func (w *WGWatcher) IsEnabled() bool {
w.muEnabled.RLock()
defer w.muEnabled.RUnlock()
return w.enabled
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
lastHandshake := initialHandshake
@@ -104,7 +94,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
return
}
if lastHandshake.IsZero() {
elapsed := handshake.Sub(w.enabledTime).Seconds()
elapsed := calcElapsed(enabledTime, *handshake)
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
}
@@ -134,19 +124,19 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
// the current know handshake did not change
if handshake.Equal(lastHandshake) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
return nil, false
}
// in case if the machine is suspended, the handshake time will be in the past
if handshake.Add(checkPeriod).Before(time.Now()) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
return nil, false
}
// error handling for handshake time in the future
if handshake.After(time.Now()) {
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake is in the future: %v", handshake)
return nil, false
}
@@ -164,3 +154,13 @@ func (w *WGWatcher) wgState() (time.Time, error) {
}
return wgState.LastHandshake, nil
}
// calcElapsed calculates elapsed time since watcher was enabled.
// The watcher started after the wg configuration happens, because of this need to normalise the negative value
func calcElapsed(enabledTime, handshake time.Time) float64 {
elapsed := handshake.Sub(enabledTime).Seconds()
if elapsed < 0 {
elapsed = 0
}
return elapsed
}

View File

@@ -2,6 +2,7 @@ package peer
import (
"context"
"sync"
"testing"
"time"
@@ -48,7 +49,6 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}
func TestWGWatcher_ReEnable(t *testing.T) {
@@ -60,14 +60,21 @@ func TestWGWatcher_ReEnable(t *testing.T) {
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
watcher.EnableWgWatcher(ctx, func() {})
}()
cancel()
wg.Wait()
// Re-enable with a new context
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, func() {})
time.Sleep(1 * time.Second)
watcher.DisableWgWatcher()
go watcher.EnableWgWatcher(ctx, func() {
onDisconnected <- struct{}{}
})
@@ -80,5 +87,4 @@ func TestWGWatcher_ReEnable(t *testing.T) {
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
@@ -286,8 +287,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())),
RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())),
Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local),
}
@@ -328,13 +329,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
addrString := pair.Remote.Address()
parsed, err := netip.ParseAddr(addrString)
if (err == nil) && (parsed.Is6()) {
addrString = fmt.Sprintf("[%s]", addrString)
//IPv6 Literals need to be wrapped in brackets for Resolve*Addr()
}
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort))
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort)))
if err != nil {
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
return
@@ -386,12 +381,44 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent,
}
}
func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
sessionID := w.SessionID()
stats := agent.GetCandidatePairsStats()
localCandidates, _ := agent.GetLocalCandidates()
remoteCandidates, _ := agent.GetRemoteCandidates()
localMap := make(map[string]ice.Candidate)
for _, c := range localCandidates {
localMap[c.ID()] = c
}
remoteMap := make(map[string]ice.Candidate)
for _, c := range remoteCandidates {
remoteMap[c.ID()] = c
}
for _, stat := range stats {
if stat.State == ice.CandidatePairStateSucceeded {
local, lok := localMap[stat.LocalCandidateID]
remote, rok := remoteMap[stat.RemoteCandidateID]
if !lok || !rok {
continue
}
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
sessionID,
local.NetworkType(), local.Type(), local.Address(),
remote.NetworkType(), remote.Type(), remote.Address(),
stat.CurrentRoundTripTime*1000)
}
}
}
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
return func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
w.logSuccessfulPaths(agent)
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to

View File

@@ -30,11 +30,9 @@ type WorkerRelay struct {
relayLock sync.Mutex
relaySupportedOnRemotePeer atomic.Bool
wgWatcher *WGWatcher
}
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager) *WorkerRelay {
r := &WorkerRelay{
peerCtx: ctx,
log: log,
@@ -42,7 +40,6 @@ func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnC
config: config,
conn: conn,
relayManager: relayManager,
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
}
return r
}
@@ -93,14 +90,6 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
})
}
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
}
func (w *WorkerRelay) DisableWgWatcher() {
w.wgWatcher.DisableWgWatcher()
}
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
return w.relayManager.RelayInstanceAddress()
}
@@ -125,14 +114,6 @@ func (w *WorkerRelay) CloseConn() {
}
}
func (w *WorkerRelay) onWGDisconnected() {
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.conn.onRelayDisconnected()
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
if !w.relayManager.HasRelayAddress() {
return false
@@ -148,6 +129,5 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
}
func (w *WorkerRelay) onRelayClientDisconnected() {
w.wgWatcher.DisableWgWatcher()
go w.conn.onRelayDisconnected()
}

View File

@@ -34,6 +34,7 @@ type Manager struct {
server *rp.Server
lock sync.Mutex
port int
wgIface PresharedKeySetter
}
// NewManager creates a new Rosenpass manager
@@ -109,7 +110,13 @@ func (m *Manager) generateConfig() (rp.Config, error) {
cfg.SecretKey = m.ssk
cfg.Peers = []rp.PeerConfig{}
m.rpWgHandler, _ = NewNetbirdHandler(m.preSharedKey, m.ifaceName)
m.lock.Lock()
m.rpWgHandler = NewNetbirdHandler()
if m.wgIface != nil {
m.rpWgHandler.SetInterface(m.wgIface)
}
m.lock.Unlock()
cfg.Handlers = []rp.Handler{m.rpWgHandler}
@@ -172,6 +179,20 @@ func (m *Manager) Close() error {
return nil
}
// SetInterface sets the WireGuard interface for the rosenpass handler.
// This can be called before or after Run() - the interface will be stored
// and passed to the handler when it's created or updated immediately if
// already running.
func (m *Manager) SetInterface(iface PresharedKeySetter) {
m.lock.Lock()
defer m.lock.Unlock()
m.wgIface = iface
if m.rpWgHandler != nil {
m.rpWgHandler.SetInterface(iface)
}
}
// OnConnected is a handler function that is triggered when a connection to a remote peer establishes
func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) {
m.lock.Lock()
@@ -192,6 +213,20 @@ func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey [
}
}
// IsPresharedKeyInitialized returns true if Rosenpass has completed a handshake
// and set a PSK for the given WireGuard peer.
func (m *Manager) IsPresharedKeyInitialized(wireGuardPubKey string) bool {
m.lock.Lock()
defer m.lock.Unlock()
peerID, ok := m.rpPeerIDs[wireGuardPubKey]
if !ok || peerID == nil {
return false
}
return m.rpWgHandler.IsPeerInitialized(*peerID)
}
func findRandomAvailableUDPPort() (int, error) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {

View File

@@ -1,46 +1,50 @@
package rosenpass
import (
"fmt"
"log/slog"
"sync"
rp "cunicu.li/go-rosenpass"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// PresharedKeySetter is the interface for setting preshared keys on WireGuard peers.
// This minimal interface allows rosenpass to update PSKs without depending on the full WGIface.
type PresharedKeySetter interface {
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
}
type wireGuardPeer struct {
Interface string
PublicKey rp.Key
}
type NetbirdHandler struct {
ifaceName string
client *wgctrl.Client
peers map[rp.PeerID]wireGuardPeer
presharedKey [32]byte
mu sync.Mutex
iface PresharedKeySetter
peers map[rp.PeerID]wireGuardPeer
initializedPeers map[rp.PeerID]bool
}
func NewNetbirdHandler(preSharedKey *[32]byte, wgIfaceName string) (hdlr *NetbirdHandler, err error) {
hdlr = &NetbirdHandler{
ifaceName: wgIfaceName,
peers: map[rp.PeerID]wireGuardPeer{},
func NewNetbirdHandler() *NetbirdHandler {
return &NetbirdHandler{
peers: map[rp.PeerID]wireGuardPeer{},
initializedPeers: map[rp.PeerID]bool{},
}
}
if preSharedKey != nil {
hdlr.presharedKey = *preSharedKey
}
if hdlr.client, err = wgctrl.New(); err != nil {
return nil, fmt.Errorf("failed to creat WireGuard client: %w", err)
}
return hdlr, nil
// SetInterface sets the WireGuard interface for the handler.
// This must be called after the WireGuard interface is created.
func (h *NetbirdHandler) SetInterface(iface PresharedKeySetter) {
h.mu.Lock()
defer h.mu.Unlock()
h.iface = iface
}
func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
h.mu.Lock()
defer h.mu.Unlock()
h.peers[pid] = wireGuardPeer{
Interface: intf,
PublicKey: pk,
@@ -48,79 +52,61 @@ func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
}
func (h *NetbirdHandler) RemovePeer(pid rp.PeerID) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.peers, pid)
delete(h.initializedPeers, pid)
}
// IsPeerInitialized returns true if Rosenpass has completed a handshake
// and set a PSK for this peer.
func (h *NetbirdHandler) IsPeerInitialized(pid rp.PeerID) bool {
h.mu.Lock()
defer h.mu.Unlock()
return h.initializedPeers[pid]
}
func (h *NetbirdHandler) HandshakeCompleted(pid rp.PeerID, key rp.Key) {
log.Debug("Handshake complete")
h.outputKey(rp.KeyOutputReasonStale, pid, key)
}
func (h *NetbirdHandler) HandshakeExpired(pid rp.PeerID) {
key, _ := rp.GeneratePresharedKey()
log.Debug("Handshake expired")
h.outputKey(rp.KeyOutputReasonStale, pid, key)
}
func (h *NetbirdHandler) outputKey(_ rp.KeyOutputReason, pid rp.PeerID, psk rp.Key) {
h.mu.Lock()
iface := h.iface
wg, ok := h.peers[pid]
isInitialized := h.initializedPeers[pid]
h.mu.Unlock()
if iface == nil {
log.Warn("rosenpass: interface not set, cannot update preshared key")
return
}
if !ok {
return
}
device, err := h.client.Device(h.ifaceName)
if err != nil {
log.Errorf("Failed to get WireGuard device: %v", err)
peerKey := wgtypes.Key(wg.PublicKey).String()
pskKey := wgtypes.Key(psk)
// Use updateOnly=true for later rotations (peer already has Rosenpass PSK)
// Use updateOnly=false for first rotation (peer has original/empty PSK)
if err := iface.SetPresharedKey(peerKey, pskKey, isInitialized); err != nil {
log.Errorf("Failed to apply rosenpass key: %v", err)
return
}
config := []wgtypes.PeerConfig{
{
UpdateOnly: true,
PublicKey: wgtypes.Key(wg.PublicKey),
PresharedKey: (*wgtypes.Key)(&psk),
},
}
for _, peer := range device.Peers {
if peer.PublicKey == wgtypes.Key(wg.PublicKey) {
if publicKeyEmpty(peer.PresharedKey) || peer.PresharedKey == h.presharedKey {
log.Debugf("Restart wireguard connection to peer %s", peer.PublicKey)
config = []wgtypes.PeerConfig{
{
PublicKey: wgtypes.Key(wg.PublicKey),
PresharedKey: (*wgtypes.Key)(&psk),
Endpoint: peer.Endpoint,
AllowedIPs: peer.AllowedIPs,
},
}
err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
Peers: []wgtypes.PeerConfig{
{
Remove: true,
PublicKey: wgtypes.Key(wg.PublicKey),
},
},
})
if err != nil {
slog.Debug("Failed to remove peer")
return
}
}
// Mark peer as isInitialized after the successful first rotation
if !isInitialized {
h.mu.Lock()
if _, exists := h.peers[pid]; exists {
h.initializedPeers[pid] = true
}
}
if err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
Peers: config,
}); err != nil {
log.Errorf("Failed to apply rosenpass key: %v", err)
h.mu.Unlock()
}
}
func publicKeyEmpty(key wgtypes.Key) bool {
for _, b := range key {
if b != 0 {
return false
}
}
return true
}

View File

@@ -17,12 +17,13 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -37,11 +38,6 @@ type internalDNATer interface {
AddInternalDNATMapping(netip.Addr, netip.Addr) error
}
type wgInterface interface {
Name() string
Address() wgaddr.Address
}
type DnsInterceptor struct {
mu sync.RWMutex
route *route.Route
@@ -51,7 +47,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server
currentPeerKey string
interceptedDomains domainMap
wgInterface wgInterface
wgInterface iface.WGIface
peerStore *peerstore.Store
firewall firewall.Manager
fakeIPManager *fakeip.Manager
@@ -219,14 +215,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
if len(r.Question) == 0 {
return
}
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
@@ -249,12 +245,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
}
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
@@ -263,32 +253,15 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
if reply == nil {
return
}
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
resutil.SetMeta(w, "peer", peerKey)
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
if err := d.writeMsg(w, reply, logger); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
}
@@ -324,11 +297,15 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
return peerAllowedIP, nil
}
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error {
if r == nil {
return fmt.Errorf("received nil DNS message")
}
// Clear Zero bit from peer responses to prevent external sources from
// manipulating our internal fallthrough signaling mechanism
r.MsgHdr.Zero = false
if len(r.Answer) > 0 && len(r.Question) > 0 {
origPattern := ""
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
@@ -350,14 +327,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
continue
}
ip = addr
@@ -370,11 +347,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil {
logger.Errorf("failed to update domain prefixes: %v", err)
}
d.replaceIPsInDNSResponse(r, newPrefixes)
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
}
}
@@ -386,22 +363,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
// logPrefixChanges handles the logging for prefix changes
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) {
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error {
d.mu.Lock()
defer d.mu.Unlock()
@@ -418,9 +395,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
realIP := prefix.Addr()
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
dnatMappings[fakeIP] = realIP
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
} else {
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err)
}
}
}
@@ -432,7 +409,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
}
}
d.addDNATMappings(dnatMappings)
d.addDNATMappings(dnatMappings, logger)
if !d.route.KeepRoute {
// Remove old prefixes
@@ -448,7 +425,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
}
}
d.removeDNATMappings(toRemove)
d.removeDNATMappings(toRemove, logger)
}
// Update domain prefixes using resolved domain as key - store real IPs
@@ -463,14 +440,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
// Store real IPs for status (user-facing), not fake IPs
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger)
}
return nberrors.FormatErrorOrNil(merr)
}
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) {
if len(realPrefixes) == 0 {
return
}
@@ -484,9 +461,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
realIP := prefix.Addr()
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else {
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
@@ -502,7 +479,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
}
// addDNATMappings adds DNAT mappings to the firewall
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) {
if len(mappings) == 0 {
return
}
@@ -514,9 +491,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
for fakeIP, realIP := range mappings {
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
} else {
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
@@ -528,12 +505,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() {
}
for _, prefixes := range d.interceptedDomains {
d.removeDNATMappings(prefixes)
d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger()))
}
}
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) {
if _, ok := d.internalDnatFw(); !ok {
return
}
@@ -549,7 +526,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
case *dns.AAAA:
@@ -560,7 +537,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
}
}
@@ -586,6 +563,44 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
return
}
// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
// Returns the DNS reply on success, or nil on error (error responses are written internally).
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
startTime := time.Now()
nsNet := d.wgInterface.GetNet()
var reply *dns.Msg
var err error
if nsNet != nil {
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
} else {
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if clientErr != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
return nil
}
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
}
if err == nil {
return reply
}
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
return nil
}
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil {
return ""

View File

@@ -4,6 +4,8 @@ import (
"net"
"net/netip"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -18,4 +20,5 @@ type wgIfaceBase interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
}

View File

@@ -0,0 +1,76 @@
package jobexec
import (
"context"
"errors"
"fmt"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/upload-server/types"
)
const (
MaxBundleWaitTime = 60 * time.Minute // maximum wait time for bundle generation (1 hour)
)
var (
ErrJobNotImplemented = errors.New("job not implemented")
)
type Executor struct {
}
func NewExecutor() *Executor {
return &Executor{}
}
func (e *Executor) BundleJob(ctx context.Context, debugBundleDependencies debug.GeneratorDependencies, params debug.BundleConfig, waitForDuration time.Duration, mgmURL string) (string, error) {
if waitForDuration > MaxBundleWaitTime {
log.Warnf("bundle wait time %v exceeds maximum %v, capping to maximum", waitForDuration, MaxBundleWaitTime)
waitForDuration = MaxBundleWaitTime
}
if waitForDuration > 0 {
if err := waitFor(ctx, waitForDuration); err != nil {
return "", err
}
}
log.Infof("execute debug bundle generation")
bundleGenerator := debug.NewBundleGenerator(debugBundleDependencies, params)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
key, err := debug.UploadDebugBundle(ctx, types.DefaultBundleURL, mgmURL, path)
if err != nil {
log.Errorf("failed to upload debug bundle: %v", err)
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle has been generated successfully")
return key, nil
}
func waitFor(ctx context.Context, duration time.Duration) error {
log.Infof("wait for %v minutes before executing debug bundle", duration.Minutes())
select {
case <-time.After(duration):
return nil
case <-ctx.Done():
log.Infof("wait cancelled: %v", ctx.Err())
return ctx.Err()
}
}

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v3.21.12
// protoc v6.32.1
// source: daemon.proto
package proto
@@ -2757,7 +2757,6 @@ func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule {
type DebugBundleRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
@@ -2802,13 +2801,6 @@ func (x *DebugBundleRequest) GetAnonymize() bool {
return false
}
func (x *DebugBundleRequest) GetStatus() string {
if x != nil {
return x.Status
}
return ""
}
func (x *DebugBundleRequest) GetSystemInfo() bool {
if x != nil {
return x.SystemInfo
@@ -5372,6 +5364,154 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
return 0
}
// StartCPUProfileRequest for starting CPU profiling
type StartCPUProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StartCPUProfileRequest) Reset() {
*x = StartCPUProfileRequest{}
mi := &file_daemon_proto_msgTypes[79]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StartCPUProfileRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StartCPUProfileRequest) ProtoMessage() {}
func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[79]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead.
func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{79}
}
// StartCPUProfileResponse confirms CPU profiling has started
type StartCPUProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StartCPUProfileResponse) Reset() {
*x = StartCPUProfileResponse{}
mi := &file_daemon_proto_msgTypes[80]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StartCPUProfileResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StartCPUProfileResponse) ProtoMessage() {}
func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead.
func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{80}
}
// StopCPUProfileRequest for stopping CPU profiling
type StopCPUProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StopCPUProfileRequest) Reset() {
*x = StopCPUProfileRequest{}
mi := &file_daemon_proto_msgTypes[81]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StopCPUProfileRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StopCPUProfileRequest) ProtoMessage() {}
func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[81]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead.
func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{81}
}
// StopCPUProfileResponse confirms CPU profiling has stopped
type StopCPUProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StopCPUProfileResponse) Reset() {
*x = StopCPUProfileResponse{}
mi := &file_daemon_proto_msgTypes[82]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StopCPUProfileResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StopCPUProfileResponse) ProtoMessage() {}
func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[82]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead.
func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{82}
}
type InstallerResultRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -5380,7 +5520,7 @@ type InstallerResultRequest struct {
func (x *InstallerResultRequest) Reset() {
*x = InstallerResultRequest{}
mi := &file_daemon_proto_msgTypes[79]
mi := &file_daemon_proto_msgTypes[83]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5392,7 +5532,7 @@ func (x *InstallerResultRequest) String() string {
func (*InstallerResultRequest) ProtoMessage() {}
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[79]
mi := &file_daemon_proto_msgTypes[83]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5405,7 +5545,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{79}
return file_daemon_proto_rawDescGZIP(), []int{83}
}
type InstallerResultResponse struct {
@@ -5418,7 +5558,7 @@ type InstallerResultResponse struct {
func (x *InstallerResultResponse) Reset() {
*x = InstallerResultResponse{}
mi := &file_daemon_proto_msgTypes[80]
mi := &file_daemon_proto_msgTypes[84]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5430,7 +5570,7 @@ func (x *InstallerResultResponse) String() string {
func (*InstallerResultResponse) ProtoMessage() {}
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
mi := &file_daemon_proto_msgTypes[84]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5443,7 +5583,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{80}
return file_daemon_proto_rawDescGZIP(), []int{84}
}
func (x *InstallerResultResponse) GetSuccess() bool {
@@ -5470,7 +5610,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[82]
mi := &file_daemon_proto_msgTypes[86]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5482,7 +5622,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[82]
mi := &file_daemon_proto_msgTypes[86]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5773,10 +5913,9 @@ const file_daemon_proto_rawDesc = "" +
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
"\x17ForwardingRulesResponse\x12,\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xac\x01\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x94\x01\n" +
"\x12DebugBundleRequest\x12\x1c\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x16\n" +
"\x06status\x18\x02 \x01(\tR\x06status\x12\x1e\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
"\n" +
"systemInfo\x18\x03 \x01(\bR\n" +
"systemInfo\x12\x1c\n" +
@@ -6003,6 +6142,10 @@ const file_daemon_proto_rawDesc = "" +
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
"\x16StartCPUProfileRequest\"\x19\n" +
"\x17StartCPUProfileResponse\"\x17\n" +
"\x15StopCPUProfileRequest\"\x18\n" +
"\x16StopCPUProfileResponse\"\x18\n" +
"\x16InstallerResultRequest\"O\n" +
"\x17InstallerResultResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
@@ -6015,7 +6158,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xb4\x13\n" +
"\x05TRACE\x10\a2\xdd\x14\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -6050,7 +6193,9 @@ const file_daemon_proto_rawDesc = "" +
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" +
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
@@ -6067,7 +6212,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
@@ -6152,21 +6297,25 @@ var file_daemon_proto_goTypes = []any{
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
(*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
nil, // 85: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
nil, // 87: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 88: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
nil, // 89: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range
nil, // 91: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
@@ -6177,8 +6326,8 @@ var file_daemon_proto_depIdxs = []int32{
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -6189,10 +6338,10 @@ var file_daemon_proto_depIdxs = []int32{
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -6226,43 +6375,47 @@ var file_daemon_proto_depIdxs = []int32{
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
67, // [67:100] is the sub-list for method output_type
34, // [34:67] is the sub-list for method input_type
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
69, // [69:104] is the sub-list for method output_type
34, // [34:69] is the sub-list for method input_type
34, // [34:34] is the sub-list for extension type_name
34, // [34:34] is the sub-list for extension extendee
0, // [0:34] is the sub-list for field type_name
@@ -6292,7 +6445,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 4,
NumMessages: 84,
NumMessages: 88,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -94,6 +94,12 @@ service DaemonService {
// WaitJWTToken waits for JWT authentication completion
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
// StartCPUProfile starts CPU profiling in the daemon
rpc StartCPUProfile(StartCPUProfileRequest) returns (StartCPUProfileResponse) {}
// StopCPUProfile stops CPU profiling in the daemon
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
@@ -455,7 +461,6 @@ message ForwardingRulesResponse {
// DebugBundler
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;
bool systemInfo = 3;
string uploadURL = 4;
uint32 logFileCount = 5;
@@ -777,6 +782,18 @@ message WaitJWTTokenResponse {
int64 expiresIn = 3;
}
// StartCPUProfileRequest for starting CPU profiling
message StartCPUProfileRequest {}
// StartCPUProfileResponse confirms CPU profiling has started
message StartCPUProfileResponse {}
// StopCPUProfileRequest for stopping CPU profiling
message StopCPUProfileRequest {}
// StopCPUProfileResponse confirms CPU profiling has stopped
message StopCPUProfileResponse {}
message InstallerResultRequest {
}

View File

@@ -70,6 +70,10 @@ type DaemonServiceClient interface {
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
// StartCPUProfile starts CPU profiling in the daemon
StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error)
// StopCPUProfile stops CPU profiling in the daemon
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
}
@@ -384,6 +388,24 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken
return out, nil
}
func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) {
out := new(StartCPUProfileResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartCPUProfile", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) {
out := new(StopCPUProfileResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopCPUProfile", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
out := new(OSLifecycleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
@@ -458,6 +480,10 @@ type DaemonServiceServer interface {
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
// StartCPUProfile starts CPU profiling in the daemon
StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error)
// StopCPUProfile stops CPU profiling in the daemon
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
@@ -560,6 +586,12 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
}
func (UnimplementedDaemonServiceServer) StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method StartCPUProfile not implemented")
}
func (UnimplementedDaemonServiceServer) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method StopCPUProfile not implemented")
}
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
}
@@ -1140,6 +1172,42 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d
return interceptor(ctx, in, info, handler)
}
func _DaemonService_StartCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StartCPUProfileRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).StartCPUProfile(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/StartCPUProfile",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).StartCPUProfile(ctx, req.(*StartCPUProfileRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StopCPUProfileRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).StopCPUProfile(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/StopCPUProfile",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).StopCPUProfile(ctx, req.(*StopCPUProfileRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(OSLifecycleRequest)
if err := dec(in); err != nil {
@@ -1303,6 +1371,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "WaitJWTToken",
Handler: _DaemonService_WaitJWTToken_Handler,
},
{
MethodName: "StartCPUProfile",
Handler: _DaemonService_StartCPUProfile_Handler,
},
{
MethodName: "StopCPUProfile",
Handler: _DaemonService_StopCPUProfile_Handler,
},
{
MethodName: "NotifyOSLifecycle",
Handler: _DaemonService_NotifyOSLifecycle_Handler,

View File

@@ -3,25 +3,19 @@
package server
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"runtime/pprof"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
const maxBundleUploadSize = 50 * 1024 * 1024
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
@@ -32,16 +26,24 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
log.Warnf("failed to get latest sync response: %v", err)
}
var cpuProfileData []byte
if s.cpuProfileBuf != nil && !s.cpuProfiling {
cpuProfileData = s.cpuProfileBuf.Bytes()
defer func() {
s.cpuProfileBuf = nil
}()
}
bundleGenerator := debug.NewBundleGenerator(
debug.GeneratorDependencies{
InternalConfig: s.config,
StatusRecorder: s.statusRecorder,
SyncResponse: syncResponse,
LogFile: s.logFile,
LogPath: s.logFile,
CPUProfile: cpuProfileData,
},
debug.BundleConfig{
Anonymize: req.GetAnonymize(),
ClientStatus: req.GetStatus(),
IncludeSystemInfo: req.GetSystemInfo(),
LogFileCount: req.GetLogFileCount(),
},
@@ -55,7 +57,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
if req.GetUploadURL() == "" {
return &proto.DebugBundleResponse{Path: path}, nil
}
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
key, err := debug.UploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
if err != nil {
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
@@ -66,92 +68,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
}
func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}
// GetLogLevel gets the current logging level for the server.
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
s.mutex.Lock()
@@ -173,20 +89,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
log.SetLevel(level)
if s.connectClient == nil {
return nil, fmt.Errorf("connect client not initialized")
if s.connectClient != nil {
s.connectClient.SetLogLevel(level)
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("engine not initialized")
}
fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, fmt.Errorf("firewall manager not initialized")
}
fwManager.SetLogLevel(level)
log.Infof("Log level set to %s", level.String())
@@ -215,3 +120,43 @@ func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) {
return cClient.GetLatestSyncResponse()
}
// StartCPUProfile starts CPU profiling in the daemon.
func (s *Server) StartCPUProfile(_ context.Context, _ *proto.StartCPUProfileRequest) (*proto.StartCPUProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.cpuProfiling {
return nil, fmt.Errorf("CPU profiling already in progress")
}
s.cpuProfileBuf = &bytes.Buffer{}
s.cpuProfiling = true
if err := pprof.StartCPUProfile(s.cpuProfileBuf); err != nil {
s.cpuProfileBuf = nil
s.cpuProfiling = false
return nil, fmt.Errorf("start CPU profile: %w", err)
}
log.Info("CPU profiling started")
return &proto.StartCPUProfileResponse{}, nil
}
// StopCPUProfile stops CPU profiling in the daemon.
func (s *Server) StopCPUProfile(_ context.Context, _ *proto.StopCPUProfileRequest) (*proto.StopCPUProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.cpuProfiling {
return nil, fmt.Errorf("CPU profiling not in progress")
}
pprof.StopCPUProfile()
s.cpuProfiling = false
if s.cpuProfileBuf != nil {
log.Infof("CPU profiling stopped, captured %d bytes", s.cpuProfileBuf.Len())
}
return &proto.StopCPUProfileResponse{}, nil
}

View File

@@ -1,8 +1,6 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
@@ -29,8 +27,3 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
}
}
}
func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
events := s.statusRecorder.GetEventHistory()
return &proto.GetEventsResponse{Events: events}, nil
}

View File

@@ -1,6 +1,7 @@
package server
import (
"bytes"
"context"
"errors"
"fmt"
@@ -13,15 +14,11 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -70,7 +67,7 @@ type Server struct {
proto.UnimplementedDaemonServiceServer
clientRunning bool // protected by mutex
clientRunningChan chan struct{}
clientGiveUpChan chan struct{}
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
connectClient *internal.ConnectClient
@@ -81,6 +78,9 @@ type Server struct {
persistSyncResponse bool
isSessionActive atomic.Bool
cpuProfileBuf *bytes.Buffer
cpuProfiling bool
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
@@ -796,9 +796,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
// Down engine work in the daemon.
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
giveUpChan := s.clientGiveUpChan
if err := s.cleanupConnection(); err != nil {
s.mutex.Unlock()
// todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err)
return nil, err
@@ -807,6 +809,20 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
s.mutex.Unlock()
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
// The giveUpChan is closed at the end of connectWithRetryRuns.
if giveUpChan != nil {
select {
case <-giveUpChan:
log.Debugf("client goroutine finished successfully")
case <-time.After(5 * time.Second):
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
}
}
return &proto.DownResponse{}, nil
}
@@ -1067,11 +1083,9 @@ func (s *Server) Status(
if msg.GetFullPeerStatus {
s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1526,7 +1540,7 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
if err := s.connectClient.Run(runningChan, s.logFile); err != nil {
return err
}
return nil
@@ -1600,94 +1614,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
return defaultDuration
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
@@ -306,6 +307,8 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -317,7 +320,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
@@ -326,7 +329,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -132,7 +132,7 @@ func TestSSHProxy_Connect(t *testing.T) {
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}

View File

@@ -43,7 +43,7 @@ func TestJWTEnforcement(t *testing.T) {
t.Run("blocks_without_jwt", func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
Audiences: []string{"test-audience"},
KeysLocation: "test-keys",
}
serverConfig := &Config{
@@ -202,7 +202,7 @@ func TestJWTDetection(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
}
serverConfig := &Config{
@@ -329,7 +329,7 @@ func TestJWTFailClose(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
MaxTokenAge: 3600,
}
@@ -567,7 +567,7 @@ func TestJWTAuthentication(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
}
serverConfig := &Config{
@@ -646,3 +646,108 @@ func TestJWTAuthentication(t *testing.T) {
})
}
}
// TestJWTMultipleAudiences tests JWT validation with multiple audiences (dashboard and CLI).
func TestJWTMultipleAudiences(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT multiple audiences tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
dashboardAudience = "dashboard-audience"
cliAudience = "cli-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
audience string
wantAuthOK bool
}{
{
name: "accepts_dashboard_audience",
audience: dashboardAudience,
wantAuthOK: true,
},
{
name: "accepts_cli_audience",
audience: cliAudience,
wantAuthOK: true,
},
{
name: "rejects_unknown_audience",
audience: "unknown-audience",
wantAuthOK: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audiences: []string{dashboardAudience, cliAudience},
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
testUserHash, err := sshuserhash.HashUserID("test-user")
require.NoError(t, err)
currentUser := testutil.GetTestUsername(t)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
currentUser: {0},
},
}
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
token := generateValidJWT(t, privateKey, issuer, tc.audience)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{
cryptossh.Password(token),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if tc.wantAuthOK {
require.NoError(t, err, "JWT authentication should succeed for audience %s", tc.audience)
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
err = session.Shell()
require.NoError(t, err, "Shell should work with valid audience")
} else {
assert.Error(t, err, "JWT authentication should fail for unknown audience")
}
})
}
}

View File

@@ -176,9 +176,9 @@ type Server struct {
type JWTConfig struct {
Issuer string
Audience string
KeysLocation string
MaxTokenAge int64
Audiences []string
}
// Config contains all SSH server configuration options
@@ -427,18 +427,21 @@ func (s *Server) ensureJWTValidator() error {
return fmt.Errorf("JWT config not set")
}
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
if len(config.Audiences) == 0 {
return fmt.Errorf("JWT config has no audiences configured")
}
log.Debugf("Initializing JWT validator (issuer: %s, audiences: %v)", config.Issuer, config.Audiences)
validator := jwt.NewValidator(
config.Issuer,
[]string{config.Audience},
config.Audiences,
config.KeysLocation,
true,
)
// Use custom userIDClaim from authorizer if available
extractorOptions := []jwt.ClaimsExtractorOption{
jwt.WithAudience(config.Audience),
jwt.WithAudience(config.Audiences[0]),
}
if authorizer.GetUserIDClaim() != "" {
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
@@ -475,8 +478,8 @@ func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
if err != nil {
if jwtConfig != nil {
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
return nil, fmt.Errorf("validate token (expected issuer=%s, audiences=%v, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audiences, claims["iss"], claims["aud"], err)
}
}
return nil, fmt.Errorf("validate token: %w", err)

View File

@@ -11,8 +11,12 @@ import (
"strings"
"time"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"gopkg.in/yaml.v3"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
@@ -116,9 +120,7 @@ type OutputOverview struct {
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
}
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
pbFullStatus := resp.GetFullStatus()
func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, daemonVersion string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
managementState := pbFullStatus.GetManagementState()
managementOverview := ManagementStateOutput{
URL: managementState.GetURL(),
@@ -134,13 +136,13 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
}
relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
peersOverview := mapPeers(pbFullStatus.GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
overview := OutputOverview{
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: resp.GetDaemonVersion(),
DaemonVersion: daemonVersion,
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
@@ -325,61 +327,64 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
}
}
func ParseToJSON(overview OutputOverview) (string, error) {
jsonBytes, err := json.Marshal(overview)
// JSON returns the status overview as a JSON string.
func (o *OutputOverview) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
func ParseToYAML(overview OutputOverview) (string, error) {
yamlBytes, err := yaml.Marshal(overview)
// YAML returns the status overview as a YAML string.
func (o *OutputOverview) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
// GeneralSummary returns a general summary of the status overview.
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
var managementConnString string
if overview.ManagementState.Connected {
if o.ManagementState.Connected {
managementConnString = "Connected"
if showURL {
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL)
}
} else {
managementConnString = "Disconnected"
if overview.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
if o.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error)
}
}
var signalConnString string
if overview.SignalState.Connected {
if o.SignalState.Connected {
signalConnString = "Connected"
if showURL {
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL)
}
} else {
signalConnString = "Disconnected"
if overview.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
if o.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error)
}
}
interfaceTypeString := "Userspace"
interfaceIP := overview.IP
if overview.KernelInterface {
interfaceIP := o.IP
if o.KernelInterface {
interfaceTypeString = "Kernel"
} else if overview.IP == "" {
} else if o.IP == "" {
interfaceTypeString = "N/A"
interfaceIP = "N/A"
}
var relaysString string
if showRelays {
for _, relay := range overview.Relays.Details {
for _, relay := range o.Relays.Details {
available := "Available"
reason := ""
@@ -395,18 +400,18 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total)
}
networks := "-"
if len(overview.Networks) > 0 {
sort.Strings(overview.Networks)
networks = strings.Join(overview.Networks, ", ")
if len(o.Networks) > 0 {
sort.Strings(o.Networks)
networks = strings.Join(o.Networks, ", ")
}
var dnsServersString string
if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups {
for _, nsServerGroup := range o.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
@@ -430,25 +435,25 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups))
}
rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled {
if o.RosenpassEnabled {
rosenpassEnabledStatus = "true"
if overview.RosenpassPermissive {
if o.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
}
}
lazyConnectionEnabledStatus := "false"
if overview.LazyConnectionEnabled {
if o.LazyConnectionEnabled {
lazyConnectionEnabledStatus = "true"
}
sshServerStatus := "Disabled"
if overview.SSHServerState.Enabled {
sessionCount := len(overview.SSHServerState.Sessions)
if o.SSHServerState.Enabled {
sessionCount := len(o.SSHServerState.Sessions)
if sessionCount > 0 {
sessionWord := "session"
if sessionCount > 1 {
@@ -460,7 +465,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
}
if showSSHSessions && sessionCount > 0 {
for _, session := range overview.SSHServerState.Sessions {
for _, session := range o.SSHServerState.Sessions {
var sessionDisplay string
if session.JWTUsername != "" {
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
@@ -484,7 +489,12 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
var forwardingRulesString string
if o.NumberOfForwardingRules > 0 {
forwardingRulesString = fmt.Sprintf("Forwarding rules: %d\n", o.NumberOfForwardingRules)
}
goos := runtime.GOOS
goarch := runtime.GOARCH
@@ -509,33 +519,34 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"Networks: %s\n"+
"Forwarding rules: %d\n"+
"%s"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
o.DaemonVersion,
version.NetbirdVersion(),
overview.ProfileName,
o.ProfileName,
managementConnString,
signalConnString,
relaysString,
dnsServersString,
domain.Domain(overview.FQDN).SafeString(),
domain.Domain(o.FQDN).SafeString(),
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
networks,
overview.NumberOfForwardingRules,
forwardingRulesString,
peersCountString,
)
return summary
}
func ParseToFullDetailSummary(overview OutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
parsedEventsString := parseEvents(overview.Events)
summary := ParseGeneralSummary(overview, true, true, true, true)
// FullDetailSummary returns a full detailed summary with peer details and events.
func (o *OutputOverview) FullDetailSummary() string {
parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive)
parsedEventsString := parseEvents(o.Events)
summary := o.GeneralSummary(true, true, true, true)
return fmt.Sprintf(
"Peers detail:"+
@@ -549,6 +560,94 @@ func ParseToFullDetailSummary(overview OutputOverview) string {
)
}
func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
var (
peersString = ""

View File

@@ -238,7 +238,7 @@ var overview = OutputOverview{
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "", "")
convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), false, resp.GetDaemonVersion(), "", nil, nil, nil, "", "")
assert.Equal(t, overview, convertedResult)
}
@@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) {
}
func TestParsingToJSON(t *testing.T) {
jsonString, _ := ParseToJSON(overview)
jsonString, _ := overview.JSON()
//@formatter:off
expectedJSONString := `
@@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) {
}
func TestParsingToYAML(t *testing.T) {
yaml, _ := ParseToYAML(overview)
yaml, _ := overview.YAML()
expectedYAML :=
`peers:
@@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) {
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := ParseToFullDetailSummary(overview)
detail := overview.FullDetailSummary()
expectedDetail := fmt.Sprintf(
`Peers detail:
@@ -567,7 +567,6 @@ Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
Networks: 10.10.0.0/24
Forwarding rules: 0
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
@@ -575,7 +574,7 @@ Peers count: 2/2 Connected
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
shortVersion := overview.GeneralSummary(false, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
@@ -592,7 +591,6 @@ Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
Networks: 10.10.0.0/24
Forwarding rules: 0
Peers count: 2/2 Connected
`

View File

@@ -18,9 +18,7 @@ import (
"github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types"
)
@@ -291,19 +289,18 @@ func (s *serviceClient) handleRunForDuration(
return
}
statusOutput, err := s.collectDebugData(conn, initialState, params, progressUI)
if err != nil {
defer s.restoreServiceState(conn, initialState)
if err := s.collectDebugData(conn, initialState, params, progressUI); err != nil {
handleError(progressUI, err.Error())
return
}
if err := s.createDebugBundleFromCollection(conn, params, statusOutput, progressUI); err != nil {
if err := s.createDebugBundleFromCollection(conn, params, progressUI); err != nil {
handleError(progressUI, err.Error())
return
}
s.restoreServiceState(conn, initialState)
progressUI.statusLabel.SetText("Bundle created successfully")
}
@@ -409,6 +406,10 @@ func (s *serviceClient) configureServiceForDebug(
}
time.Sleep(time.Second * 3)
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
log.Warnf("failed to start CPU profiling: %v", err)
}
return nil
}
@@ -417,68 +418,37 @@ func (s *serviceClient) collectDebugData(
state *debugInitialState,
params *debugCollectionParams,
progress *progressUI,
) (string, error) {
) error {
ctx, cancel := context.WithTimeout(s.ctx, params.duration)
defer cancel()
var wg sync.WaitGroup
startProgressTracker(ctx, &wg, params.duration, progress)
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
return "", err
return err
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get post-up status: %v", err)
}
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
wg.Wait()
progress.progressBar.Hide()
progress.statusLabel.SetText("Collecting debug data...")
preDownStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get pre-down status: %v", err)
if _, err := conn.StopCPUProfile(s.ctx, &proto.StopCPUProfileRequest{}); err != nil {
log.Warnf("failed to stop CPU profiling: %v", err)
}
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
time.Now().Format(time.RFC3339), params.duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, preDownStatusOutput)
return statusOutput, nil
return nil
}
// Create the debug bundle with collected data
func (s *serviceClient) createDebugBundleFromCollection(
conn proto.DaemonServiceClient,
params *debugCollectionParams,
statusOutput string,
progress *progressUI,
) error {
progress.statusLabel.SetText("Creating debug bundle with collected logs...")
request := &proto.DebugBundleRequest{
Anonymize: params.anonymize,
Status: statusOutput,
SystemInfo: params.systemInfo,
}
@@ -581,26 +551,8 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
return nil, fmt.Errorf("get client: %v", err)
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("failed to get status for debug bundle: %v", err)
}
var statusOutput string
if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
request := &proto.DebugBundleRequest{
Anonymize: anonymize,
Status: statusOutput,
SystemInfo: systemInfo,
}

View File

@@ -9,20 +9,28 @@ import (
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
netbird "github.com/netbirdio/netbird/client/embed"
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const (
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
pingTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
icmpEchoRequest = 8
icmpCodeEcho = 0
pingBufferSize = 1500
)
func main() {
@@ -113,18 +121,45 @@ func createStopMethod(client *netbird.Client) js.Func {
})
}
// validateSSHArgs validates SSH connection arguments
func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) {
if len(args) < 2 {
return "", 0, "", js.ValueOf("error: requires host and port")
}
if args[0].Type() != js.TypeString {
return "", 0, "", js.ValueOf("host parameter must be a string")
}
if args[1].Type() != js.TypeNumber {
return "", 0, "", js.ValueOf("port parameter must be a number")
}
host = args[0].String()
port = args[1].Int()
username = "root"
if len(args) > 2 {
if args[2].Type() == js.TypeString && args[2].String() != "" {
username = args[2].String()
} else if args[2].Type() != js.TypeString {
return "", 0, "", js.ValueOf("username parameter must be a string")
}
}
return host, port, username, js.Undefined()
}
// createSSHMethod creates the SSH connection method
func createSSHMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: requires host and port")
}
host := args[0].String()
port := args[1].Int()
username := "root"
if len(args) > 2 && args[2].String() != "" {
username = args[2].String()
host, port, username, validationErr := validateSSHArgs(args)
if !validationErr.IsUndefined() {
if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" {
return validationErr
}
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(validationErr)
})
}
var jwtToken string
@@ -154,6 +189,110 @@ func createSSHMethod(client *netbird.Client) js.Func {
})
}
func performPing(client *netbird.Client, hostname string) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
start := time.Now()
conn, err := client.Dial(ctx, "ping", hostname)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
return
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close ping connection: %v", err)
}
}()
icmpData := make([]byte, 8)
icmpData[0] = icmpEchoRequest
icmpData[1] = icmpCodeEcho
if _, err := conn.Write(icmpData); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err))
return
}
buf := make([]byte, pingBufferSize)
if _, err := conn.Read(buf); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err))
return
}
latency := time.Since(start)
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
}
func performPingTCP(client *netbird.Client, hostname string, port int) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
address := fmt.Sprintf("%s:%d", hostname, port)
start := time.Now()
conn, err := client.Dial(ctx, "tcp", address)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
return
}
latency := time.Since(start)
if err := conn.Close(); err != nil {
log.Debugf("failed to close TCP connection: %v", err)
}
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
}
// createPingMethod creates the ping method
func createPingMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: hostname required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
hostname := args[0].String()
return createPromise(func(resolve, reject js.Value) {
performPing(client, hostname)
resolve.Invoke(js.Undefined())
})
})
}
// createPingTCPMethod creates the pingtcp method
func createPingTCPMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: hostname and port required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
if args[1].Type() != js.TypeNumber {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a number"))
})
}
hostname := args[0].String()
port := args[1].Int()
return createPromise(func(resolve, reject js.Value) {
performPingTCP(client, hostname, port)
resolve.Invoke(js.Undefined())
})
})
}
// createProxyRequestMethod creates the proxyRequest method
func createProxyRequestMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
@@ -162,6 +301,11 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
}
request := args[0]
if request.Type() != js.TypeObject {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("request parameter must be an object"))
})
}
return createPromise(func(resolve, reject js.Value) {
response, err := http.ProxyRequest(client, request)
@@ -181,11 +325,141 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
return js.ValueOf("error: hostname and port required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
if args[1].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a string"))
})
}
proxy := rdp.NewRDCleanPathProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String())
})
}
// getStatusOverview is a helper to get the status overview
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
fullStatus, err := client.Status()
if err != nil {
return nbstatus.OutputOverview{}, err
}
pbFullStatus := fullStatus.ToProto()
return nbstatus.ConvertToStatusOutputOverview(pbFullStatus, false, version.NetbirdVersion(), "", nil, nil, nil, "", ""), nil
}
// createStatusMethod creates the status method that returns JSON
func createStatusMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonStr, err := overview.JSON()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", jsonStr)
resolve.Invoke(jsonObj)
})
})
}
// createStatusSummaryMethod creates the statusSummary method
func createStatusSummaryMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
summary := overview.GeneralSummary(false, false, false, false)
js.Global().Get("console").Call("log", summary)
resolve.Invoke(js.Undefined())
})
})
}
// createStatusDetailMethod creates the statusDetail method
func createStatusDetailMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
detail := overview.FullDetailSummary()
js.Global().Get("console").Call("log", detail)
resolve.Invoke(js.Undefined())
})
})
}
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
syncResp, err := client.GetLatestSyncResponse()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
options := protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
AllowPartial: true,
}
jsonBytes, err := options.Marshal(syncResp)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err)))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes))
resolve.Invoke(jsonObj)
})
})
}
// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level
func createSetLogLevelMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: log level required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("log level parameter must be a string"))
})
}
logLevel := args[0].String()
return createPromise(func(resolve, reject js.Value) {
if err := client.SetLogLevel(logLevel); err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err)))
return
}
log.Infof("Log level set to: %s", logLevel)
resolve.Invoke(js.ValueOf(true))
})
})
}
// createPromise is a helper to create JavaScript promises
func createPromise(handler func(resolve, reject js.Value)) js.Value {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
@@ -237,17 +511,24 @@ func createClientObject(client *netbird.Client) js.Value {
obj["start"] = createStartMethod(client)
obj["stop"] = createStopMethod(client)
obj["ping"] = createPingMethod(client)
obj["pingtcp"] = createPingTCPMethod(client)
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
obj["setLogLevel"] = createSetLogLevelMethod(client)
return js.ValueOf(obj)
}
// netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(this js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]

View File

@@ -47,8 +47,8 @@ type CustomZone struct {
Records []SimpleRecord
// SearchDomainDisabled indicates whether to add match domains to a search domains list or not
SearchDomainDisabled bool
// SkipPTRProcess indicates whether a client should process PTR records from custom zones
SkipPTRProcess bool
// NonAuthoritative marks user-created zones
NonAuthoritative bool
}
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records

12
go.mod
View File

@@ -68,8 +68,9 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -78,8 +79,8 @@ require (
github.com/pion/logging v0.2.4
github.com/pion/randutil v0.1.0
github.com/pion/stun/v2 v2.0.0
github.com/pion/stun/v3 v3.0.0
github.com/pion/transport/v3 v3.0.7
github.com/pion/stun/v3 v3.1.0
github.com/pion/transport/v3 v3.1.1
github.com/pion/turn/v3 v3.0.1
github.com/pkg/sftp v1.13.9
github.com/prometheus/client_golang v1.23.2
@@ -141,6 +142,7 @@ require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/awnumar/memcall v0.4.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
@@ -241,7 +243,7 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/dtls/v3 v3.0.7 // indirect
github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pion/turn/v4 v4.1.1 // indirect
@@ -263,7 +265,7 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.3 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect

28
go.sum
View File

@@ -35,12 +35,15 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0=
github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6CtY2Qg/JVQ=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
@@ -87,6 +90,7 @@ github.com/beevik/etree v1.6.0 h1:u8Kwy8pp9D9XeITj2Z0XtA5qqZEmtJtuXZRQi+j03eE=
github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sLc0Gc=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -320,6 +324,7 @@ github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7X
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
@@ -401,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
@@ -416,6 +421,8 @@ github.com/nicksnyder/go-i18n/v2 v2.5.1/go.mod h1:DrhgsSDZxoAfvVrBVLXoxZn/pN5TXq
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/oapi-codegen/runtime v1.1.2 h1:P2+CubHq8fO4Q6fV1tqDBZHCwpVpvPg7oKiYzQgXIyI=
github.com/oapi-codegen/runtime v1.1.2/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg=
github.com/okta/okta-sdk-golang/v2 v2.18.0 h1:cfDasMb7CShbZvOrF6n+DnLevWwiHgedWMGJ8M8xKDc=
github.com/okta/okta-sdk-golang/v2 v2.18.0/go.mod h1:dz30v3ctAiMb7jpsCngGfQUAEGm1/NsWT92uTbNDQIs=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@@ -444,8 +451,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
@@ -455,14 +462,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
@@ -522,6 +529,7 @@ github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=
@@ -574,8 +582,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=

356
idp/dex/connector.go Normal file
View File

@@ -0,0 +1,356 @@
// Package dex provides an embedded Dex OIDC identity provider.
package dex
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/dexidp/dex/storage"
)
// ConnectorConfig represents the configuration for an identity provider connector
type ConnectorConfig struct {
// ID is the unique identifier for the connector
ID string
// Name is a human-readable name for the connector
Name string
// Type is the connector type (oidc, google, microsoft)
Type string
// Issuer is the OIDC issuer URL (for OIDC-based connectors)
Issuer string
// ClientID is the OAuth2 client ID
ClientID string
// ClientSecret is the OAuth2 client secret
ClientSecret string
// RedirectURI is the OAuth2 redirect URI
RedirectURI string
}
// CreateConnector creates a new connector in Dex storage.
// It maps the connector config to the appropriate Dex connector type and configuration.
func (p *Provider) CreateConnector(ctx context.Context, cfg *ConnectorConfig) (*ConnectorConfig, error) {
// Fill in the redirect URI if not provided
if cfg.RedirectURI == "" {
cfg.RedirectURI = p.GetRedirectURI()
}
storageConn, err := p.buildStorageConnector(cfg)
if err != nil {
return nil, fmt.Errorf("failed to build connector: %w", err)
}
if err := p.storage.CreateConnector(ctx, storageConn); err != nil {
return nil, fmt.Errorf("failed to create connector: %w", err)
}
p.logger.Info("connector created", "id", cfg.ID, "type", cfg.Type)
return cfg, nil
}
// GetConnector retrieves a connector by ID from Dex storage.
func (p *Provider) GetConnector(ctx context.Context, id string) (*ConnectorConfig, error) {
conn, err := p.storage.GetConnector(ctx, id)
if err != nil {
if err == storage.ErrNotFound {
return nil, err
}
return nil, fmt.Errorf("failed to get connector: %w", err)
}
return p.parseStorageConnector(conn)
}
// ListConnectors returns all connectors from Dex storage (excluding the local connector).
func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, error) {
connectors, err := p.storage.ListConnectors(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list connectors: %w", err)
}
result := make([]*ConnectorConfig, 0, len(connectors))
for _, conn := range connectors {
// Skip the local password connector
if conn.ID == "local" && conn.Type == "local" {
continue
}
cfg, err := p.parseStorageConnector(conn)
if err != nil {
p.logger.Warn("failed to parse connector", "id", conn.ID, "error", err)
continue
}
result = append(result, cfg)
}
return result, nil
}
// UpdateConnector updates an existing connector in Dex storage.
// It merges incoming updates with existing values to prevent data loss on partial updates.
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
oldCfg, err := p.parseStorageConnector(old)
if err != nil {
return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
}
mergeConnectorConfig(cfg, oldCfg)
storageConn, err := p.buildStorageConnector(cfg)
if err != nil {
return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
}
return storageConn, nil
}); err != nil {
return fmt.Errorf("failed to update connector: %w", err)
}
p.logger.Info("connector updated", "id", cfg.ID, "type", cfg.Type)
return nil
}
// mergeConnectorConfig preserves existing values for empty fields in the update.
func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
if cfg.ClientSecret == "" {
cfg.ClientSecret = oldCfg.ClientSecret
}
if cfg.RedirectURI == "" {
cfg.RedirectURI = oldCfg.RedirectURI
}
if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
cfg.Issuer = oldCfg.Issuer
}
if cfg.ClientID == "" {
cfg.ClientID = oldCfg.ClientID
}
if cfg.Name == "" {
cfg.Name = oldCfg.Name
}
}
// DeleteConnector removes a connector from Dex storage.
func (p *Provider) DeleteConnector(ctx context.Context, id string) error {
// Prevent deletion of the local connector
if id == "local" {
return fmt.Errorf("cannot delete the local password connector")
}
if err := p.storage.DeleteConnector(ctx, id); err != nil {
return fmt.Errorf("failed to delete connector: %w", err)
}
p.logger.Info("connector deleted", "id", id)
return nil
}
// GetRedirectURI returns the default redirect URI for connectors.
func (p *Provider) GetRedirectURI() string {
if p.config == nil {
return ""
}
issuer := strings.TrimSuffix(p.config.Issuer, "/")
if !strings.HasSuffix(issuer, "/oauth2") {
issuer += "/oauth2"
}
return issuer + "/callback"
}
// buildStorageConnector creates a storage.Connector from ConnectorConfig.
// It handles the type-specific configuration for each connector type.
func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connector, error) {
redirectURI := p.resolveRedirectURI(cfg.RedirectURI)
var dexType string
var configData []byte
var err error
switch cfg.Type {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
dexType = "oidc"
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
case "google":
dexType = "google"
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
case "microsoft":
dexType = "microsoft"
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
default:
return storage.Connector{}, fmt.Errorf("unsupported connector type: %s", cfg.Type)
}
if err != nil {
return storage.Connector{}, err
}
return storage.Connector{ID: cfg.ID, Type: dexType, Name: cfg.Name, Config: configData}, nil
}
// resolveRedirectURI returns the redirect URI, using a default if not provided
func (p *Provider) resolveRedirectURI(redirectURI string) string {
if redirectURI != "" || p.config == nil {
return redirectURI
}
issuer := strings.TrimSuffix(p.config.Issuer, "/")
if !strings.HasSuffix(issuer, "/oauth2") {
issuer += "/oauth2"
}
return issuer + "/callback"
}
// buildOIDCConnectorConfig creates config for OIDC-based connectors
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
oidcConfig := map[string]interface{}{
"issuer": cfg.Issuer,
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
"scopes": []string{"openid", "profile", "email"},
"insecureEnableGroups": true,
//some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo)
"insecureSkipEmailVerified": true,
}
switch cfg.Type {
case "zitadel":
oidcConfig["getUserInfo"] = true
case "entra":
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
case "okta":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
}
return encodeConnectorConfig(oidcConfig)
}
// buildOAuth2ConnectorConfig creates config for OAuth2 connectors (google, microsoft)
func buildOAuth2ConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
return encodeConnectorConfig(map[string]interface{}{
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
})
}
// parseStorageConnector converts a storage.Connector back to ConnectorConfig.
// It infers the original identity provider type from the Dex connector type and ID.
func (p *Provider) parseStorageConnector(conn storage.Connector) (*ConnectorConfig, error) {
cfg := &ConnectorConfig{
ID: conn.ID,
Name: conn.Name,
}
if len(conn.Config) == 0 {
cfg.Type = conn.Type
return cfg, nil
}
var configMap map[string]interface{}
if err := decodeConnectorConfig(conn.Config, &configMap); err != nil {
return nil, fmt.Errorf("failed to parse connector config: %w", err)
}
// Extract common fields
if v, ok := configMap["clientID"].(string); ok {
cfg.ClientID = v
}
if v, ok := configMap["clientSecret"].(string); ok {
cfg.ClientSecret = v
}
if v, ok := configMap["redirectURI"].(string); ok {
cfg.RedirectURI = v
}
if v, ok := configMap["issuer"].(string); ok {
cfg.Issuer = v
}
// Infer the original identity provider type from Dex connector type and ID
cfg.Type = inferIdentityProviderType(conn.Type, conn.ID, configMap)
return cfg, nil
}
// inferIdentityProviderType determines the original identity provider type
// based on the Dex connector type, connector ID, and configuration.
func inferIdentityProviderType(dexType, connectorID string, _ map[string]interface{}) string {
if dexType != "oidc" {
return dexType
}
return inferOIDCProviderType(connectorID)
}
// inferOIDCProviderType infers the specific OIDC provider from connector ID
func inferOIDCProviderType(connectorID string) string {
connectorIDLower := strings.ToLower(connectorID)
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
if strings.Contains(connectorIDLower, provider) {
return provider
}
}
return "oidc"
}
// encodeConnectorConfig serializes connector config to JSON bytes.
func encodeConnectorConfig(config map[string]interface{}) ([]byte, error) {
return json.Marshal(config)
}
// decodeConnectorConfig deserializes connector config from JSON bytes.
func decodeConnectorConfig(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
// ensureLocalConnector creates a local (password) connector if it doesn't exist
func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
// Check specifically for the local connector
_, err := stor.GetConnector(ctx, "local")
if err == nil {
// Local connector already exists
return nil
}
if !errors.Is(err, storage.ErrNotFound) {
return fmt.Errorf("failed to get local connector: %w", err)
}
// Create a local connector for password authentication
localConnector := storage.Connector{
ID: "local",
Type: "local",
Name: "Email",
}
if err := stor.CreateConnector(ctx, localConnector); err != nil {
return fmt.Errorf("failed to create local connector: %w", err)
}
return nil
}
// ensureStaticConnectors creates or updates static connectors in storage
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
for _, conn := range connectors {
storConn, err := conn.ToStorageConnector()
if err != nil {
return fmt.Errorf("failed to convert connector %s: %w", conn.ID, err)
}
_, err = stor.GetConnector(ctx, conn.ID)
if err == storage.ErrNotFound {
if err := stor.CreateConnector(ctx, storConn); err != nil {
return fmt.Errorf("failed to create connector %s: %w", conn.ID, err)
}
continue
}
if err != nil {
return fmt.Errorf("failed to get connector %s: %w", conn.ID, err)
}
if err := stor.UpdateConnector(ctx, conn.ID, func(old storage.Connector) (storage.Connector, error) {
old.Name = storConn.Name
old.Config = storConn.Config
return old, nil
}); err != nil {
return fmt.Errorf("failed to update connector %s: %w", conn.ID, err)
}
}
return nil
}

View File

@@ -4,7 +4,6 @@ package dex
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
@@ -245,34 +244,6 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
return nil
}
// ensureStaticConnectors creates or updates static connectors in storage
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
for _, conn := range connectors {
storConn, err := conn.ToStorageConnector()
if err != nil {
return fmt.Errorf("failed to convert connector %s: %w", conn.ID, err)
}
_, err = stor.GetConnector(ctx, conn.ID)
if errors.Is(err, storage.ErrNotFound) {
if err := stor.CreateConnector(ctx, storConn); err != nil {
return fmt.Errorf("failed to create connector %s: %w", conn.ID, err)
}
continue
}
if err != nil {
return fmt.Errorf("failed to get connector %s: %w", conn.ID, err)
}
if err := stor.UpdateConnector(ctx, conn.ID, func(old storage.Connector) (storage.Connector, error) {
old.Name = storConn.Name
old.Config = storConn.Config
return old, nil
}); err != nil {
return fmt.Errorf("failed to update connector %s: %w", conn.ID, err)
}
}
return nil
}
// buildDexConfig creates a server.Config with defaults applied
func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config {
cfg := yamlConfig.ToServerConfig(stor, logger)
@@ -613,290 +584,37 @@ func (p *Provider) ListUsers(ctx context.Context) ([]storage.Password, error) {
return p.storage.ListPasswords(ctx)
}
// ensureLocalConnector creates a local (password) connector if none exists
func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
connectors, err := stor.ListConnectors(ctx)
// UpdateUserPassword updates the password for a user identified by userID.
// The userID can be either an encoded Dex ID (base64 protobuf) or a raw UUID.
// It verifies the current password before updating.
func (p *Provider) UpdateUserPassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
// Get the user by ID to find their email
user, err := p.GetUserByID(ctx, userID)
if err != nil {
return fmt.Errorf("failed to list connectors: %w", err)
return fmt.Errorf("failed to get user: %w", err)
}
// If any connector exists, we're good
if len(connectors) > 0 {
return nil
// Verify old password
if err := bcrypt.CompareHashAndPassword(user.Hash, []byte(oldPassword)); err != nil {
return fmt.Errorf("current password is incorrect")
}
// Create a local connector for password authentication
localConnector := storage.Connector{
ID: "local",
Type: "local",
Name: "Email",
}
if err := stor.CreateConnector(ctx, localConnector); err != nil {
return fmt.Errorf("failed to create local connector: %w", err)
}
return nil
}
// ConnectorConfig represents the configuration for an identity provider connector
type ConnectorConfig struct {
// ID is the unique identifier for the connector
ID string
// Name is a human-readable name for the connector
Name string
// Type is the connector type (oidc, google, microsoft)
Type string
// Issuer is the OIDC issuer URL (for OIDC-based connectors)
Issuer string
// ClientID is the OAuth2 client ID
ClientID string
// ClientSecret is the OAuth2 client secret
ClientSecret string
// RedirectURI is the OAuth2 redirect URI
RedirectURI string
}
// CreateConnector creates a new connector in Dex storage.
// It maps the connector config to the appropriate Dex connector type and configuration.
func (p *Provider) CreateConnector(ctx context.Context, cfg *ConnectorConfig) (*ConnectorConfig, error) {
// Fill in the redirect URI if not provided
if cfg.RedirectURI == "" {
cfg.RedirectURI = p.GetRedirectURI()
}
storageConn, err := p.buildStorageConnector(cfg)
// Hash the new password
newHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to build connector: %w", err)
return fmt.Errorf("failed to hash new password: %w", err)
}
if err := p.storage.CreateConnector(ctx, storageConn); err != nil {
return nil, fmt.Errorf("failed to create connector: %w", err)
}
p.logger.Info("connector created", "id", cfg.ID, "type", cfg.Type)
return cfg, nil
}
// GetConnector retrieves a connector by ID from Dex storage.
func (p *Provider) GetConnector(ctx context.Context, id string) (*ConnectorConfig, error) {
conn, err := p.storage.GetConnector(ctx, id)
if err != nil {
if err == storage.ErrNotFound {
return nil, err
}
return nil, fmt.Errorf("failed to get connector: %w", err)
}
return p.parseStorageConnector(conn)
}
// ListConnectors returns all connectors from Dex storage (excluding the local connector).
func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, error) {
connectors, err := p.storage.ListConnectors(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list connectors: %w", err)
}
result := make([]*ConnectorConfig, 0, len(connectors))
for _, conn := range connectors {
// Skip the local password connector
if conn.ID == "local" && conn.Type == "local" {
continue
}
cfg, err := p.parseStorageConnector(conn)
if err != nil {
p.logger.Warn("failed to parse connector", "id", conn.ID, "error", err)
continue
}
result = append(result, cfg)
}
return result, nil
}
// UpdateConnector updates an existing connector in Dex storage.
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
storageConn, err := p.buildStorageConnector(cfg)
if err != nil {
return fmt.Errorf("failed to build connector: %w", err)
}
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
return storageConn, nil
}); err != nil {
return fmt.Errorf("failed to update connector: %w", err)
}
p.logger.Info("connector updated", "id", cfg.ID, "type", cfg.Type)
return nil
}
// DeleteConnector removes a connector from Dex storage.
func (p *Provider) DeleteConnector(ctx context.Context, id string) error {
// Prevent deletion of the local connector
if id == "local" {
return fmt.Errorf("cannot delete the local password connector")
}
if err := p.storage.DeleteConnector(ctx, id); err != nil {
return fmt.Errorf("failed to delete connector: %w", err)
}
p.logger.Info("connector deleted", "id", id)
return nil
}
// buildStorageConnector creates a storage.Connector from ConnectorConfig.
// It handles the type-specific configuration for each connector type.
func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connector, error) {
redirectURI := p.resolveRedirectURI(cfg.RedirectURI)
var dexType string
var configData []byte
var err error
switch cfg.Type {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
dexType = "oidc"
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
case "google":
dexType = "google"
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
case "microsoft":
dexType = "microsoft"
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
default:
return storage.Connector{}, fmt.Errorf("unsupported connector type: %s", cfg.Type)
}
if err != nil {
return storage.Connector{}, err
}
return storage.Connector{ID: cfg.ID, Type: dexType, Name: cfg.Name, Config: configData}, nil
}
// resolveRedirectURI returns the redirect URI, using a default if not provided
func (p *Provider) resolveRedirectURI(redirectURI string) string {
if redirectURI != "" || p.config == nil {
return redirectURI
}
issuer := strings.TrimSuffix(p.config.Issuer, "/")
if !strings.HasSuffix(issuer, "/oauth2") {
issuer += "/oauth2"
}
return issuer + "/callback"
}
// buildOIDCConnectorConfig creates config for OIDC-based connectors
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
oidcConfig := map[string]interface{}{
"issuer": cfg.Issuer,
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
"scopes": []string{"openid", "profile", "email"},
}
switch cfg.Type {
case "zitadel":
oidcConfig["getUserInfo"] = true
case "entra":
oidcConfig["insecureSkipEmailVerified"] = true
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
case "okta":
oidcConfig["insecureSkipEmailVerified"] = true
}
return encodeConnectorConfig(oidcConfig)
}
// buildOAuth2ConnectorConfig creates config for OAuth2 connectors (google, microsoft)
func buildOAuth2ConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
return encodeConnectorConfig(map[string]interface{}{
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
// Update the password in storage
err = p.storage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
old.Hash = newHash
return old, nil
})
}
// parseStorageConnector converts a storage.Connector back to ConnectorConfig.
// It infers the original identity provider type from the Dex connector type and ID.
func (p *Provider) parseStorageConnector(conn storage.Connector) (*ConnectorConfig, error) {
cfg := &ConnectorConfig{
ID: conn.ID,
Name: conn.Name,
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
if len(conn.Config) == 0 {
cfg.Type = conn.Type
return cfg, nil
}
var configMap map[string]interface{}
if err := decodeConnectorConfig(conn.Config, &configMap); err != nil {
return nil, fmt.Errorf("failed to parse connector config: %w", err)
}
// Extract common fields
if v, ok := configMap["clientID"].(string); ok {
cfg.ClientID = v
}
if v, ok := configMap["clientSecret"].(string); ok {
cfg.ClientSecret = v
}
if v, ok := configMap["redirectURI"].(string); ok {
cfg.RedirectURI = v
}
if v, ok := configMap["issuer"].(string); ok {
cfg.Issuer = v
}
// Infer the original identity provider type from Dex connector type and ID
cfg.Type = inferIdentityProviderType(conn.Type, conn.ID, configMap)
return cfg, nil
}
// inferIdentityProviderType determines the original identity provider type
// based on the Dex connector type, connector ID, and configuration.
func inferIdentityProviderType(dexType, connectorID string, _ map[string]interface{}) string {
if dexType != "oidc" {
return dexType
}
return inferOIDCProviderType(connectorID)
}
// inferOIDCProviderType infers the specific OIDC provider from connector ID
func inferOIDCProviderType(connectorID string) string {
connectorIDLower := strings.ToLower(connectorID)
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
if strings.Contains(connectorIDLower, provider) {
return provider
}
}
return "oidc"
}
// encodeConnectorConfig serializes connector config to JSON bytes.
func encodeConnectorConfig(config map[string]interface{}) ([]byte, error) {
return json.Marshal(config)
}
// decodeConnectorConfig deserializes connector config from JSON bytes.
func decodeConnectorConfig(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
// GetRedirectURI returns the default redirect URI for connectors.
func (p *Provider) GetRedirectURI() string {
if p.config == nil {
return ""
}
issuer := strings.TrimSuffix(p.config.Issuer, "/")
if !strings.HasSuffix(issuer, "/oauth2") {
issuer += "/oauth2"
}
return issuer + "/callback"
return nil
}
// GetIssuer returns the OIDC issuer URL.

File diff suppressed because it is too large Load Diff

View File

@@ -143,7 +143,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Confi
applyCommandLineOverrides(loadedConfig)
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
err := applyEmbeddedIdPConfig(loadedConfig)
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
if err != nil {
return nil, err
}
@@ -177,7 +177,7 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
@@ -193,11 +193,6 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
// Set LocalAddress for embedded IdP if enabled, used for internal JWT validation
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
// Ensure HttpConfig exists
if cfg.HttpConfig == nil {
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
// Set storage defaults based on Datadir
if cfg.EmbeddedIdP.Storage.Type == "" {
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
@@ -208,40 +203,22 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
issuer := cfg.EmbeddedIdP.Issuer
// Set AuthIssuer from EmbeddedIdP issuer
if cfg.HttpConfig.AuthIssuer == "" {
cfg.HttpConfig.AuthIssuer = issuer
if cfg.HttpConfig != nil {
log.WithContext(ctx).Warnf("overriding HttpConfig with EmbeddedIdP config. " +
"HttpConfig is ignored when EmbeddedIdP is enabled. Please remove HttpConfig section from the config file")
} else {
// Ensure HttpConfig exists. We need it for backwards compatibility with the old config format.
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
// Set AuthAudience to the dashboard client ID
if cfg.HttpConfig.AuthAudience == "" {
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
}
// Set CLIAuthAudience to the client app client ID
if cfg.HttpConfig.CLIAuthAudience == "" {
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
}
// Set AuthUserIDClaim to "sub" (standard OIDC claim)
if cfg.HttpConfig.AuthUserIDClaim == "" {
cfg.HttpConfig.AuthUserIDClaim = "sub"
}
// Set AuthKeysLocation to the JWKS endpoint
if cfg.HttpConfig.AuthKeysLocation == "" {
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
}
// Set OIDCConfigEndpoint to the discovery endpoint
if cfg.HttpConfig.OIDCConfigEndpoint == "" {
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
}
// Copy SignKeyRefreshEnabled from EmbeddedIdP config
if cfg.EmbeddedIdP.SignKeyRefreshEnabled {
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
}
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
return nil
}
@@ -249,7 +226,12 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint == "" || cfg.EmbeddedIdP != nil {
if oidcEndpoint == "" {
return nil
}
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
// skip OIDC config fetching if EmbeddedIdP is enabled as it is unnecessary given it is embedded
return nil
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -62,6 +63,8 @@ type Controller struct {
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
}
type bufferUpdate struct {
@@ -84,6 +87,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
newNetworkMapBuilder = false
}
compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err)
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
@@ -107,6 +116,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
}
}
@@ -175,7 +186,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
@@ -197,6 +208,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return fmt.Errorf("failed to get account zones: %v", err)
}
for _, peer := range account.Peers {
if !c.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
@@ -223,9 +240,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
if c.compactedNetworkMap {
account.ShadowCompareNetworkMap(ctx, p.ID, remotePeerNetworkMap, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs, c.accountManagerMetrics)
}
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -318,7 +338,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
@@ -335,12 +355,21 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return err
}
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
if c.compactedNetworkMap {
account.ShadowCompareNetworkMap(ctx, peerId, remotePeerNetworkMap, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs, c.accountManagerMetrics)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -434,7 +463,14 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
@@ -445,11 +481,15 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
if c.compactedNetworkMap {
account.ShadowCompareNetworkMap(ctx, peer.ID, networkMap, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs, c.accountManagerMetrics)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -472,7 +512,8 @@ func (c *Controller) getPeerNetworkMapExp(
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(ctx, accountId)
@@ -483,7 +524,7 @@ func (c *Controller) getPeerNetworkMapExp(
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
@@ -798,7 +839,15 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if err != nil {
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
@@ -809,11 +858,14 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
if c.compactedNetworkMap {
account.ShadowCompareNetworkMap(ctx, peer.ID, networkMap, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, account.GetActiveGroupUsers(), c.accountManagerMetrics)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -3,6 +3,7 @@ package controller
import (
"context"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -14,6 +15,7 @@ type Repository interface {
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
}
type repository struct {
@@ -47,3 +49,7 @@ func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerID
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}

View File

@@ -31,6 +31,7 @@ type Manager interface {
SetNetworkMapController(networkMapController network_map.Controller)
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
SetAccountManager(accountManager account.Manager)
GetPeerID(ctx context.Context, peerKey string) (string, error)
}
type managerImpl struct {
@@ -167,3 +168,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
return nil
}
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
}

View File

@@ -97,6 +97,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
}
// GetPeerID mocks base method.
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerID", ctx, peerKey)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerID indicates an expected call of GetPeerID.
func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
}
// GetPeersByGroupIDs mocks base method.
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()

View File

@@ -0,0 +1,13 @@
package zones
import (
"context"
)
type Manager interface {
GetAllZones(ctx context.Context, accountID, userID string) ([]*Zone, error)
GetZone(ctx context.Context, accountID, userID, zone string) (*Zone, error)
CreateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
UpdateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
DeleteZone(ctx context.Context, accountID, userID, zoneID string) error
}

View File

@@ -0,0 +1,161 @@
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager zones.Manager
}
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiZones := make([]*api.Zone, 0, len(allZones))
for _, zone := range allZones {
apiZones = append(apiZones, zone.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiZones)
}
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiDnsZonesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdZone, err := h.manager.CreateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
}
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
zone, err := h.manager.GetZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
}
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
var req api.PutApiDnsZonesZoneIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
zone.ID = zoneID
if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedZone, err := h.manager.UpdateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
}
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,229 @@
package manager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
dnsDomain string
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
dnsDomain: dnsDomain,
}
}
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
return nil, err
}
zone = zones.NewZone(accountID, zone.Name, zone.Domain, zone.Enabled, zone.EnableSearchDomain, zone.DistributionGroups)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, zone.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing zone: %w", err)
}
}
if existingZone != nil {
return status.Errorf(status.AlreadyExists, "zone with domain %s already exists", zone.Domain)
}
for _, groupID := range zone.DistributionGroups {
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
if err = transaction.CreateZone(ctx, zone); err != nil {
return fmt.Errorf("failed to create zone: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneCreated, zone.EventMeta())
return zone, nil
}
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
if err != nil {
return nil, fmt.Errorf("failed to get zone: %w", err)
}
if zone.Domain != updatedZone.Domain {
return nil, status.Errorf(status.InvalidArgument, "zone domain cannot be updated")
}
zone.Name = updatedZone.Name
zone.Enabled = updatedZone.Enabled
zone.EnableSearchDomain = updatedZone.EnableSearchDomain
zone.DistributionGroups = updatedZone.DistributionGroups
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range zone.DistributionGroups {
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
if err = transaction.UpdateZone(ctx, zone); err != nil {
return fmt.Errorf("failed to update zone: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return zone, nil
}
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
records, err := transaction.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get records: %w", err)
}
err = transaction.DeleteZoneDNSRecords(ctx, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to delete zone dns records: %w", err)
}
err = transaction.DeleteZone(ctx, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to delete zone: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
for _, record := range records {
eventsToStore = append(eventsToStore, func() {
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordDeleted, meta)
})
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, zoneID, accountID, activity.DNSZoneDeleted, zone.EventMeta())
})
return nil
})
if err != nil {
return err
}
for _, event := range eventsToStore {
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) validateZoneDomainConflict(ctx context.Context, accountID, domain string) error {
if m.dnsDomain != "" && m.dnsDomain == domain {
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
if settings.DNSDomain != "" && settings.DNSDomain == domain {
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
}
return nil
}

View File

@@ -0,0 +1,553 @@
package manager
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testZoneID = "test-zone-id"
testGroupID = "test-group-id"
testDNSDomain = "netbird.selfhosted"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
Name: "Test Group",
},
},
})
require.NoError(t, err)
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
dnsDomain: testDNSDomain,
}
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
}
func TestManagerImpl_GetAllZones(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone1 := zones.NewZone(testAccountID, "Zone 1", "zone1.example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone1)
require.NoError(t, err)
zone2 := zones.NewZone(testAccountID, "Zone 2", "zone2.example.com", false, false, []string{testGroupID})
err = testStore.CreateZone(ctx, zone2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.NoError(t, err)
assert.Len(t, result, 2)
assert.Equal(t, zone1.ID, result[0].ID)
assert.Equal(t, zone2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "test.example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Equal(t, zone.ID, result.ID)
assert.Equal(t, zone.Name, result.Name)
assert.Equal(t, zone.Domain, result.Domain)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneCreated, activityID)
}
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.NoError(t, err)
assert.NotNil(t, result)
assert.NotEmpty(t, result.ID)
assert.Equal(t, testAccountID, result.AccountID)
assert.Equal(t, inputZone.Name, result.Name)
assert.Equal(t, inputZone.Domain, result.Domain)
assert.Equal(t, inputZone.Enabled, result.Enabled)
assert.Equal(t, inputZone.EnableSearchDomain, result.EnableSearchDomain)
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("invalid group", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{"invalid-group"},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("duplicate domain", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Existing Zone", "duplicate.example.com", true, false, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "duplicate.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone with domain duplicate.example.com already exists")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.AlreadyExists, s.Type())
})
t.Run("peer DNS domain conflict", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
account, err := testStore.GetAccount(ctx, testAccountID)
require.NoError(t, err)
account.Settings.DNSDomain = "peers.example.com"
err = testStore.SaveAccount(ctx, account)
require.NoError(t, err)
inputZone := &zones.Zone{
Name: "Test Zone",
Domain: "peers.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone domain peers.example.com conflicts with peer DNS domain")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
t.Run("default DNS domain conflict", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "Test Zone",
Domain: testDNSDomain,
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), fmt.Sprintf("zone domain %s conflicts with peer DNS domain", testDNSDomain))
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
}
func TestManagerImpl_UpdateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Old Name", "example.com", false, false, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
updatedZone := &zones.Zone{
ID: existingZone.ID,
Name: "Updated Name",
Domain: "example.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, existingZone.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneUpdated, activityID)
}
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, updatedZone.Name, result.Name)
assert.Equal(t, updatedZone.Enabled, result.Enabled)
assert.Equal(t, updatedZone.EnableSearchDomain, result.EnableSearchDomain)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
})
t.Run("domain change not allowed", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
updatedZone := &zones.Zone{
ID: existingZone.ID,
Name: "Test Zone",
Domain: "different.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone domain cannot be updated")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: testZoneID,
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: "non-existent-zone",
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_DeleteZone(t *testing.T) {
ctx := context.Background()
t.Run("success with records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCallCount := 0
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCallCount++
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
}
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Equal(t, 3, storeEventCallCount)
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.Error(t, err)
zoneRecords, err := testStore.GetZoneDNSRecords(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.NoError(t, err)
assert.Empty(t, zoneRecords)
})
t.Run("success without records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, zone.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneDeleted, activityID)
}
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.Error(t, err)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
require.Error(t, err)
})
}

View File

@@ -0,0 +1,13 @@
package records
import (
"context"
)
type Manager interface {
GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*Record, error)
GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*Record, error)
CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
UpdateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error
}

View File

@@ -0,0 +1,191 @@
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager records.Manager
}
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
allRecords, err := h.manager.GetAllRecords(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiRecords := make([]*api.DNSRecord, 0, len(allRecords))
for _, record := range allRecords {
apiRecords = append(apiRecords, record.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiRecords)
}
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
record := new(records.Record)
record.FromAPIRequest(&req)
if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdRecord, err := h.manager.CreateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
}
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
record, err := h.manager.GetRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
}
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
record := new(records.Record)
record.FromAPIRequest(&req)
record.ID = recordID
if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedRecord, err := h.manager.UpdateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
}
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,236 @@
package manager
import (
"context"
"fmt"
"strings"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
}
}
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
}
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
err = validateRecordConflicts(ctx, transaction, zone, record)
if err != nil {
return err
}
if err = transaction.CreateDNSRecord(ctx, record); err != nil {
return fmt.Errorf("failed to create dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return record, nil
}
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
var record *records.Record
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, updatedRecord.ID)
if err != nil {
return fmt.Errorf("failed to get record: %w", err)
}
hasChanges := record.Name != updatedRecord.Name || record.Type != updatedRecord.Type || record.Content != updatedRecord.Content
record.Name = updatedRecord.Name
record.Type = updatedRecord.Type
record.Content = updatedRecord.Content
record.TTL = updatedRecord.TTL
if hasChanges {
if err = validateRecordConflicts(ctx, transaction, zone, record); err != nil {
return err
}
}
if err = transaction.UpdateDNSRecord(ctx, record); err != nil {
return fmt.Errorf("failed to update dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return record, nil
}
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var record *records.Record
var zone *zones.Zone
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, recordID)
if err != nil {
return fmt.Errorf("failed to get record: %w", err)
}
err = transaction.DeleteDNSRecord(ctx, accountID, zoneID, recordID)
if err != nil {
return fmt.Errorf("failed to delete dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// validateRecordConflicts checks for duplicate records and CNAME conflicts
func validateRecordConflicts(ctx context.Context, transaction store.Store, zone *zones.Zone, record *records.Record) error {
if record.Name != zone.Domain && !strings.HasSuffix(record.Name, "."+zone.Domain) {
return status.Errorf(status.InvalidArgument, "record name does not belong to zone")
}
existingRecords, err := transaction.GetZoneDNSRecordsByName(ctx, store.LockingStrengthNone, zone.AccountID, zone.ID, record.Name)
if err != nil {
return fmt.Errorf("failed to check existing records: %w", err)
}
for _, existing := range existingRecords {
if existing.ID == record.ID {
continue
}
if existing.Type == record.Type && existing.Content == record.Content {
return status.Errorf(status.AlreadyExists, "identical record already exists")
}
if record.Type == records.RecordTypeCNAME || existing.Type == records.RecordTypeCNAME {
return status.Errorf(status.InvalidArgument,
"An A, AAAA, or CNAME record with name %s already exists", record.Name)
}
}
return nil
}

View File

@@ -0,0 +1,573 @@
package manager
import (
"context"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testRecordID = "test-record-id"
testGroupID = "test-group-id"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
Name: "Test Group",
},
},
})
require.NoError(t, err)
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err = testStore.CreateZone(ctx, zone)
require.NoError(t, err)
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
}
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
}
func TestManagerImpl_GetAllRecords(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Len(t, result, 2)
assert.Equal(t, record1.ID, result[0].ID)
assert.Equal(t, record2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
assert.Equal(t, record.ID, result.ID)
assert.Equal(t, record.Name, result.Name)
assert.Equal(t, record.Type, result.Type)
assert.Equal(t, record.Content, result.Content)
assert.Equal(t, record.TTL, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success - A record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.NotEmpty(t, result.ID)
assert.Equal(t, testAccountID, result.AccountID)
assert.Equal(t, zone.ID, result.ZoneID)
assert.Equal(t, inputRecord.Name, result.Name)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
assert.Equal(t, inputRecord.TTL, result.TTL)
})
t.Run("success - AAAA record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "ipv6.example.com",
Type: records.RecordTypeAAAA,
Content: "2001:db8::1",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
})
t.Run("success - CNAME record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "www.example.com",
Type: records.RecordTypeCNAME,
Content: "example.com",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record name not in zone", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.different.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "does not belong to zone")
})
t.Run("duplicate record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "identical record already exists")
})
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeCNAME,
Content: "example.com",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "already exists")
})
}
func TestManagerImpl_UpdateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: existingRecord.ID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100", // Changed IP
TTL: 600, // Changed TTL
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, existingRecord.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordUpdated, activityID)
}
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, updatedRecord.Content, result.Content)
assert.Equal(t, updatedRecord.TTL, result.TTL)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
})
t.Run("update only TTL - no validation", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: existingRecord.ID,
Name: existingRecord.Name,
Type: existingRecord.Type,
Content: existingRecord.Content,
TTL: 600, // Only TTL changed
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
// Event should be stored
}
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, 600, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: testRecordID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: "non-existent-record",
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("update creates duplicate", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: record2.ID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "identical record already exists")
})
}
func TestManagerImpl_DeleteRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, record.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordDeleted, activityID)
}
err = manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
_, err = testStore.GetDNSRecordByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID, record.ID)
require.Error(t, err)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
require.Error(t, err)
})
}

View File

@@ -0,0 +1,129 @@
package records
import (
"errors"
"net"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
type RecordType string
const (
RecordTypeA RecordType = "A"
RecordTypeAAAA RecordType = "AAAA"
RecordTypeCNAME RecordType = "CNAME"
)
type Record struct {
AccountID string `gorm:"index"`
ZoneID string `gorm:"index"`
ID string `gorm:"primaryKey"`
Name string
Type RecordType
Content string
TTL int
}
func NewRecord(accountID, zoneID, name string, recordType RecordType, content string, ttl int) *Record {
return &Record{
ID: xid.New().String(),
AccountID: accountID,
ZoneID: zoneID,
Name: name,
Type: recordType,
Content: content,
TTL: ttl,
}
}
func (r *Record) ToAPIResponse() *api.DNSRecord {
recordType := api.DNSRecordType(r.Type)
return &api.DNSRecord{
Id: r.ID,
Name: r.Name,
Type: recordType,
Content: r.Content,
Ttl: r.TTL,
}
}
func (r *Record) FromAPIRequest(req *api.DNSRecordRequest) {
r.Name = req.Name
r.Type = RecordType(req.Type)
r.Content = req.Content
r.TTL = req.Ttl
}
func (r *Record) Validate() error {
if r.Name == "" {
return errors.New("record name is required")
}
if !util.IsValidDomain(r.Name) {
return errors.New("invalid record name format")
}
if r.Type == "" {
return errors.New("record type is required")
}
switch r.Type {
case RecordTypeA:
if err := validateIPv4(r.Content); err != nil {
return err
}
case RecordTypeAAAA:
if err := validateIPv6(r.Content); err != nil {
return err
}
case RecordTypeCNAME:
if !util.IsValidDomain(r.Content) {
return errors.New("invalid CNAME record format")
}
default:
return errors.New("invalid record type, must be A, AAAA, or CNAME")
}
if r.TTL < 0 {
return errors.New("TTL cannot be negative")
}
return nil
}
func (r *Record) EventMeta(zoneID, zoneName string) map[string]any {
return map[string]any{
"name": r.Name,
"type": string(r.Type),
"content": r.Content,
"ttl": r.TTL,
"zone_id": zoneID,
"zone_name": zoneName,
}
}
func validateIPv4(content string) error {
if content == "" {
return errors.New("A record is required") //nolint:staticcheck
}
ip := net.ParseIP(content)
if ip == nil || ip.To4() == nil {
return errors.New("A record must be a valid IPv4 address") //nolint:staticcheck
}
return nil
}
func validateIPv6(content string) error {
if content == "" {
return errors.New("AAAA record is required")
}
ip := net.ParseIP(content)
if ip == nil || ip.To4() != nil {
return errors.New("AAAA record must be a valid IPv6 address")
}
return nil
}

View File

@@ -0,0 +1,89 @@
package zones
import (
"errors"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
type Zone struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string
Enabled bool
EnableSearchDomain bool
DistributionGroups []string `gorm:"serializer:json"`
Records []*records.Record `gorm:"foreignKey:ZoneID;references:ID"`
}
func NewZone(accountID, name, domain string, enabled, enableSearchDomain bool, distributionGroups []string) *Zone {
return &Zone{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Domain: domain,
Enabled: enabled,
EnableSearchDomain: enableSearchDomain,
DistributionGroups: distributionGroups,
}
}
func (z *Zone) ToAPIResponse() *api.Zone {
apiRecords := make([]api.DNSRecord, 0, len(z.Records))
for _, record := range z.Records {
if apiRecord := record.ToAPIResponse(); apiRecord != nil {
apiRecords = append(apiRecords, *apiRecord)
}
}
return &api.Zone{
DistributionGroups: z.DistributionGroups,
Domain: z.Domain,
EnableSearchDomain: z.EnableSearchDomain,
Enabled: z.Enabled,
Id: z.ID,
Name: z.Name,
Records: apiRecords,
}
}
func (z *Zone) FromAPIRequest(req *api.ZoneRequest) {
z.Name = req.Name
z.Domain = req.Domain
z.EnableSearchDomain = req.EnableSearchDomain
z.DistributionGroups = req.DistributionGroups
enabled := true
if req.Enabled != nil {
enabled = *req.Enabled
}
z.Enabled = enabled
}
func (z *Zone) Validate() error {
if z.Name == "" {
return errors.New("zone name is required")
}
if len(z.Name) > 255 {
return errors.New("zone name exceeds maximum length of 255 characters")
}
if !util.IsValidDomain(z.Domain) {
return errors.New("invalid zone domain format")
}
if len(z.DistributionGroups) == 0 {
return errors.New("at least one distribution group is required")
}
return nil
}
func (z *Zone) EventMeta() map[string]any {
return map[string]any{"name": z.Name, "domain": z.Domain}
}

View File

@@ -92,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController(), s.IdpManager())
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -144,7 +144,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -16,6 +17,7 @@ import (
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
)
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
@@ -24,6 +26,12 @@ func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
})
}
func (s *BaseServer) JobManager() *job.Manager {
return Create(s, func() *job.Manager {
return job.NewJobManager(s.Metrics(), s.Store(), s.PeersManager())
})
}
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator(

View File

@@ -8,6 +8,10 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -83,7 +87,7 @@ func (s *BaseServer) PeersManager() peers.Manager {
func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
@@ -158,3 +162,15 @@ func (s *BaseServer) NetworksManager() networks.Manager {
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
})
}
func (s *BaseServer) ZonesManager() zones.Manager {
return Create(s, func() zones.Manager {
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
})
}
func (s *BaseServer) RecordsManager() records.Manager {
return Create(s, func() records.Manager {
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
})
}

View File

@@ -374,8 +374,10 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool {
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
@@ -432,9 +434,16 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi
if config.CLIAuthAudience != "" {
audience = config.CLIAuthAudience
}
audiences := []string{config.AuthAudience}
if config.CLIAuthAudience != "" && config.CLIAuthAudience != config.AuthAudience {
audiences = append(audiences, config.CLIAuthAudience)
}
return &proto.JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: audiences,
KeysLocation: keysLocation,
}
}

View File

@@ -6,9 +6,12 @@ import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
@@ -148,3 +151,52 @@ func generateTestData(size int) nbdns.Config {
return config
}
func TestBuildJWTConfig_Audiences(t *testing.T) {
tests := []struct {
name string
authAudience string
cliAuthAudience string
expectedAudiences []string
expectedAudience string
}{
{
name: "only_auth_audience",
authAudience: "dashboard-aud",
cliAuthAudience: "",
expectedAudiences: []string{"dashboard-aud"},
expectedAudience: "dashboard-aud",
},
{
name: "both_audiences_different",
authAudience: "dashboard-aud",
cliAuthAudience: "cli-aud",
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
expectedAudience: "cli-aud",
},
{
name: "both_audiences_same",
authAudience: "same-aud",
cliAuthAudience: "same-aud",
expectedAudiences: []string{"same-aud"},
expectedAudience: "same-aud",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
config := &nbconfig.HttpServerConfig{
AuthIssuer: "https://issuer.example.com",
AuthAudience: tc.authAudience,
CLIAuthAudience: tc.cliAuthAudience,
}
result := buildJWTConfig(config, nil)
assert.NotNil(t, result)
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
//nolint:staticcheck // SA1019: Testing backwards compatibility - Audience field must still be populated
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
@@ -26,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store"
@@ -57,6 +59,7 @@ type Server struct {
accountManager account.Manager
settingsManager settings.Manager
proto.UnimplementedManagementServiceServer
jobManager *job.Manager
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
@@ -82,6 +85,7 @@ func NewServer(
config *nbconfig.Config,
accountManager account.Manager,
settingsManager settings.Manager,
jobManager *job.Manager,
secretsManager SecretsManager,
appMetrics telemetry.AppMetrics,
authManager auth.Manager,
@@ -114,6 +118,7 @@ func NewServer(
}
return &Server{
jobManager: jobManager,
accountManager: accountManager,
settingsManager: settingsManager,
config: config,
@@ -169,6 +174,40 @@ func getRealIP(ctx context.Context) net.IP {
return nil
}
func (s *Server) Job(srv proto.ManagementService_JobServer) error {
reqStart := time.Now()
ctx := srv.Context()
peerKey, err := s.handleHandshake(ctx, srv)
if err != nil {
return err
}
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return err
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil {
return status.Errorf(codes.Unauthenticated, "peer is not registered")
}
s.startResponseReceiver(ctx, srv)
updates := s.jobManager.CreateJobChannel(ctx, accountID, peer.ID)
log.WithContext(ctx).Debugf("Job: took %v", time.Since(reqStart))
return s.sendJobsLoop(ctx, accountID, peerKey, peer, updates, srv)
}
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
@@ -289,6 +328,70 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
hello, err := srv.Recv()
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "missing hello: %v", err)
}
jobReq := &proto.JobRequest{}
peerKey, err := s.parseRequest(ctx, hello, jobReq)
if err != nil {
return wgtypes.Key{}, err
}
return peerKey, nil
}
func (s *Server) startResponseReceiver(ctx context.Context, srv proto.ManagementService_JobServer) {
go func() {
for {
msg, err := srv.Recv()
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return
}
log.WithContext(ctx).Warnf("recv job response error: %v", err)
return
}
jobResp := &proto.JobResponse{}
if _, err := s.parseRequest(ctx, msg, jobResp); err != nil {
log.WithContext(ctx).Warnf("invalid job response: %v", err)
continue
}
if err := s.jobManager.HandleResponse(ctx, jobResp, msg.WgPubKey); err != nil {
log.WithContext(ctx).Errorf("handle job response failed: %v", err)
}
}
}()
}
func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates *job.Channel, srv proto.ManagementService_JobServer) error {
// todo figure out better error handling strategy
defer s.jobManager.CloseChannel(ctx, accountID, peer.ID)
for {
event, err := updates.Event(ctx)
if err != nil {
if errors.Is(err, job.ErrJobChannelClosed) {
log.WithContext(ctx).Debugf("jobs channel for peer %s was closed", peerKey.String())
return nil
}
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
return ctx.Err()
}
if err := s.sendJob(ctx, peerKey, event, srv); err != nil {
log.WithContext(ctx).Warnf("send job failed: %v", err)
return nil
}
}
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
@@ -306,7 +409,6 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
@@ -336,7 +438,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
@@ -348,6 +450,31 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
return nil
}
// sendJob encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Event, srv proto.ManagementService_JobServer) error {
wgKey, err := s.secretsManager.GetWGKey()
if err != nil {
log.WithContext(ctx).Errorf("failed to get wg key for peer %s: %v", peerKey.String(), err)
return status.Errorf(codes.Internal, "failed processing job message")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, wgKey, job.Request)
if err != nil {
log.WithContext(ctx).Errorf("failed to encrypt job for peer %s: %v", peerKey.String(), err)
return status.Errorf(codes.Internal, "failed processing job message")
}
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
return status.Errorf(codes.Internal, "failed sending job message")
}
log.WithContext(ctx).Debugf("sent a job to peer: %s", peerKey.String())
return nil
}
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
@@ -690,8 +817,8 @@ func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty,
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error {
var err error
var turnToken *Token
if s.config.TURNConfig != nil && s.config.TURNConfig.TimeBasedCredentials {
turnToken, err = s.secretsManager.GenerateTurnToken()
if err != nil {

View File

@@ -15,6 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth"
cacheStore "github.com/eko/gocache/lib/v4/store"
@@ -70,6 +71,7 @@ type DefaultAccountManager struct {
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{}
networkMapController network_map.Controller
jobManager *job.Manager
idpManager idp.Manager
cacheManager *nbcache.AccountUserDataCache
externalCacheManager nbcache.UserDataCache
@@ -178,6 +180,7 @@ func BuildManager(
config *nbconfig.Config,
store store.Store,
networkMapController network_map.Controller,
jobManager *job.Manager,
idpManager idp.Manager,
singleAccountModeDomain string,
eventStore activity.Store,
@@ -200,6 +203,7 @@ func BuildManager(
config: config,
geo: geo,
networkMapController: networkMapController,
jobManager: jobManager,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
@@ -295,7 +299,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
return err
}
@@ -388,7 +392,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return newSettings, nil
}
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -402,6 +406,18 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, new
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
if newSettings.DNSDomain != oldSettings.DNSDomain && newSettings.DNSDomain != "" {
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, newSettings.DNSDomain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing zone: %w", err)
}
}
if existingZone != nil {
return status.Errorf(status.InvalidArgument, "peer DNS domain %s conflicts with existing custom DNS zone", newSettings.DNSDomain)
}
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
}

View File

@@ -32,6 +32,7 @@ type Manager interface {
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
@@ -129,4 +130,7 @@ type Manager interface {
CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
}

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -34,6 +35,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -397,7 +399,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@@ -1676,7 +1678,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
},
}
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
@@ -1686,7 +1688,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
assert.Len(t, emptyRoutes, 0)
}
@@ -2095,6 +2097,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T)
}
}
func TestDefaultAccountManager_UpdateAccountSettings_DNSDomainConflict(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
ctx := context.Background()
err = manager.Store.CreateZone(ctx, &zones.Zone{
ID: "test-zone-id",
AccountID: accountID,
Name: "Test Zone",
Domain: "custom.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{},
})
require.NoError(t, err, "unable to create custom DNS zone")
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{
DNSDomain: "custom.example.com",
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.Error(t, err, "expecting to fail when DNS domain conflicts with custom zone")
assert.Contains(t, err.Error(), "conflicts with existing custom DNS zone")
}
func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct {
name string
@@ -2993,13 +3024,14 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, nil, err
}

Some files were not shown because too many files have changed in this diff Show More